diff --git a/common/build.gradle b/common/build.gradle
index 109cad59cb..416c9ca20a 100644
--- a/common/build.gradle
+++ b/common/build.gradle
@@ -35,7 +35,7 @@ repositories {
 dependencies {
     api "org.antlr:antlr4-runtime:4.7.1"
     api group: 'com.google.guava', name: 'guava', version: '32.0.1-jre'
-    api group: 'org.apache.logging.log4j', name: 'log4j-core', version:'2.20.0'
+    api group: 'org.apache.logging.log4j', name: 'log4j-core', version:"${versions.log4j}"
    api group: 'org.apache.commons', name: 'commons-lang3', version: '3.12.0'
    api group: 'com.squareup.okhttp3', name: 'okhttp', version: '4.9.3'
    implementation 'com.github.babbel:okhttp-aws-signer:1.0.2'
diff --git a/common/src/main/java/org/opensearch/sql/common/setting/Settings.java b/common/src/main/java/org/opensearch/sql/common/setting/Settings.java
index 8daf0e9bf6..ae1950d81c 100644
--- a/common/src/main/java/org/opensearch/sql/common/setting/Settings.java
+++ b/common/src/main/java/org/opensearch/sql/common/setting/Settings.java
@@ -5,6 +5,8 @@
 
 package org.opensearch.sql.common.setting;
 
+import static org.opensearch.sql.common.setting.Settings.Key.SPARK_EXECUTION_SESSION_ENABLED;
+
 import com.google.common.base.Strings;
 import com.google.common.collect.ImmutableMap;
 import java.util.List;
@@ -36,7 +38,9 @@ public enum Key {
     METRICS_ROLLING_WINDOW("plugins.query.metrics.rolling_window"),
     METRICS_ROLLING_INTERVAL("plugins.query.metrics.rolling_interval"),
     SPARK_EXECUTION_ENGINE_CONFIG("plugins.query.executionengine.spark.config"),
-    CLUSTER_NAME("cluster.name");
+    CLUSTER_NAME("cluster.name"),
+    SPARK_EXECUTION_SESSION_ENABLED("plugins.query.executionengine.spark.session.enabled"),
+    SPARK_EXECUTION_SESSION_LIMIT("plugins.query.executionengine.spark.session.limit");
 
     @Getter private final String keyValue;
 
@@ -60,4 +64,9 @@ public static Optional<Key> of(String keyValue) {
   public abstract <T> T getSettingValue(Key key);
 
   public abstract List<?> getSettings();
+
+  /** Helper method to check whether Spark execution session is enabled. */
+  public static boolean isSparkExecutionSessionEnabled(Settings settings) {
+    return settings.getSettingValue(SPARK_EXECUTION_SESSION_ENABLED);
+  }
 }
diff --git a/core/src/main/java/org/opensearch/sql/datasource/DataSourceService.java b/core/src/main/java/org/opensearch/sql/datasource/DataSourceService.java
index 6dace50f99..162fe9e8f8 100644
--- a/core/src/main/java/org/opensearch/sql/datasource/DataSourceService.java
+++ b/core/src/main/java/org/opensearch/sql/datasource/DataSourceService.java
@@ -5,6 +5,7 @@
 
 package org.opensearch.sql.datasource;
 
+import java.util.Map;
 import java.util.Set;
 import org.opensearch.sql.datasource.model.DataSource;
 import org.opensearch.sql.datasource.model.DataSourceMetadata;
@@ -56,12 +57,19 @@ public interface DataSourceService {
   void createDataSource(DataSourceMetadata metadata);
 
   /**
-   * Updates {@link DataSource} corresponding to dataSourceMetadata.
+   * Updates {@link DataSource} corresponding to dataSourceMetadata (all fields required).
    *
    * @param dataSourceMetadata {@link DataSourceMetadata}.
    */
   void updateDataSource(DataSourceMetadata dataSourceMetadata);
 
+  /**
+   * Patches {@link DataSource} corresponding to the given name (only the fields to be changed are
+   * required).
+   *
+   * @param dataSourceData map of field names to their new values.
+   */
+  void patchDataSource(Map<String, Object> dataSourceData);
+
   /**
    * Deletes {@link DataSource} corresponding to the DataSource name.
   *
diff --git a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTestBase.java b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTestBase.java
index 508567582b..569cdd96f8 100644
--- a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTestBase.java
+++ b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTestBase.java
@@ -231,6 +231,9 @@ public DataSource getDataSource(String dataSourceName) {
       @Override
       public void updateDataSource(DataSourceMetadata dataSourceMetadata) {}
 
+      @Override
+      public void patchDataSource(Map<String, Object> dataSourceData) {}
+
       @Override
       public void deleteDataSource(String dataSourceName) {}
 
diff --git a/datasources/src/main/java/org/opensearch/sql/datasources/model/transport/PatchDataSourceActionRequest.java b/datasources/src/main/java/org/opensearch/sql/datasources/model/transport/PatchDataSourceActionRequest.java
new file mode 100644
index 0000000000..9443ea561e
--- /dev/null
+++ b/datasources/src/main/java/org/opensearch/sql/datasources/model/transport/PatchDataSourceActionRequest.java
@@ -0,0 +1,49 @@
+/*
+ *
+ * * Copyright OpenSearch Contributors
+ * * SPDX-License-Identifier: Apache-2.0
+ *
+ */
+
+package org.opensearch.sql.datasources.model.transport;
+
+import static org.opensearch.sql.analysis.DataSourceSchemaIdentifierNameResolver.DEFAULT_DATASOURCE_NAME;
+import static org.opensearch.sql.datasources.utils.XContentParserUtils.CONNECTOR_FIELD;
+import static org.opensearch.sql.datasources.utils.XContentParserUtils.NAME_FIELD;
+
+import java.io.IOException;
+import java.util.Map;
+import lombok.Getter;
+import org.opensearch.action.ActionRequest;
+import org.opensearch.action.ActionRequestValidationException;
+import org.opensearch.core.common.io.stream.StreamInput;
+
+public class PatchDataSourceActionRequest extends ActionRequest {
+
+  @Getter private Map<String, Object> dataSourceData;
+
+  /** Constructor of PatchDataSourceActionRequest from StreamInput. */
+  public PatchDataSourceActionRequest(StreamInput in) throws IOException {
+    super(in);
+  }
+
+  public PatchDataSourceActionRequest(Map<String, Object> dataSourceData) {
+    this.dataSourceData = dataSourceData;
+  }
+
+  @Override
+  public ActionRequestValidationException validate() {
+    if (this.dataSourceData.get(NAME_FIELD).equals(DEFAULT_DATASOURCE_NAME)) {
+      ActionRequestValidationException exception = new ActionRequestValidationException();
+      exception.addValidationError(
+          "Not allowed to update datasource with name : " + DEFAULT_DATASOURCE_NAME);
+      return exception;
+    } else if (this.dataSourceData.get(CONNECTOR_FIELD) != null) {
+      ActionRequestValidationException exception = new ActionRequestValidationException();
+      exception.addValidationError("Not allowed to update connector for datasource");
+      return exception;
+    } else {
+      return null;
+    }
+  }
+}
diff --git a/datasources/src/main/java/org/opensearch/sql/datasources/model/transport/PatchDataSourceActionResponse.java b/datasources/src/main/java/org/opensearch/sql/datasources/model/transport/PatchDataSourceActionResponse.java
new file mode 100644
index 0000000000..18873a6731
--- /dev/null
+++ b/datasources/src/main/java/org/opensearch/sql/datasources/model/transport/PatchDataSourceActionResponse.java
@@ -0,0 +1,31 @@
+/*
+ *
+ * * Copyright OpenSearch Contributors
+ * * SPDX-License-Identifier: Apache-2.0
+ *
+ */
+
+package org.opensearch.sql.datasources.model.transport;
+
+import java.io.IOException;
+import lombok.Getter;
+import lombok.RequiredArgsConstructor;
+import org.opensearch.core.action.ActionResponse;
+import org.opensearch.core.common.io.stream.StreamInput;
+import org.opensearch.core.common.io.stream.StreamOutput;
+
+@RequiredArgsConstructor
+public class PatchDataSourceActionResponse extends ActionResponse {
+
+  @Getter private final String result;
+
+  public PatchDataSourceActionResponse(StreamInput in) throws IOException {
+    super(in);
+    result = in.readString();
+  }
+
+  @Override
+  public void writeTo(StreamOutput streamOutput) throws IOException {
+    streamOutput.writeString(result);
+  }
+}
diff --git a/datasources/src/main/java/org/opensearch/sql/datasources/rest/RestDataSourceQueryAction.java b/datasources/src/main/java/org/opensearch/sql/datasources/rest/RestDataSourceQueryAction.java
index 2947afc5b9..c207f55738 100644
--- a/datasources/src/main/java/org/opensearch/sql/datasources/rest/RestDataSourceQueryAction.java
+++ b/datasources/src/main/java/org/opensearch/sql/datasources/rest/RestDataSourceQueryAction.java
@@ -10,15 +10,13 @@
 import static org.opensearch.core.rest.RestStatus.BAD_REQUEST;
 import static org.opensearch.core.rest.RestStatus.NOT_FOUND;
 import static org.opensearch.core.rest.RestStatus.SERVICE_UNAVAILABLE;
-import static org.opensearch.rest.RestRequest.Method.DELETE;
-import static org.opensearch.rest.RestRequest.Method.GET;
-import static org.opensearch.rest.RestRequest.Method.POST;
-import static org.opensearch.rest.RestRequest.Method.PUT;
+import static org.opensearch.rest.RestRequest.Method.*;
 
 import com.google.common.collect.ImmutableList;
 import java.io.IOException;
 import java.util.List;
 import java.util.Locale;
+import java.util.Map;
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.opensearch.OpenSearchException;
@@ -32,18 +30,8 @@
 import org.opensearch.sql.datasource.model.DataSourceMetadata;
 import org.opensearch.sql.datasources.exceptions.DataSourceNotFoundException;
 import org.opensearch.sql.datasources.exceptions.ErrorMessage;
-import org.opensearch.sql.datasources.model.transport.CreateDataSourceActionRequest;
-import org.opensearch.sql.datasources.model.transport.CreateDataSourceActionResponse;
-import org.opensearch.sql.datasources.model.transport.DeleteDataSourceActionRequest;
-import org.opensearch.sql.datasources.model.transport.DeleteDataSourceActionResponse;
-import org.opensearch.sql.datasources.model.transport.GetDataSourceActionRequest;
-import org.opensearch.sql.datasources.model.transport.GetDataSourceActionResponse;
-import org.opensearch.sql.datasources.model.transport.UpdateDataSourceActionRequest;
-import org.opensearch.sql.datasources.model.transport.UpdateDataSourceActionResponse;
-import org.opensearch.sql.datasources.transport.TransportCreateDataSourceAction;
-import org.opensearch.sql.datasources.transport.TransportDeleteDataSourceAction;
-import org.opensearch.sql.datasources.transport.TransportGetDataSourceAction;
-import org.opensearch.sql.datasources.transport.TransportUpdateDataSourceAction;
+import org.opensearch.sql.datasources.model.transport.*;
+import org.opensearch.sql.datasources.transport.*;
 import org.opensearch.sql.datasources.utils.Scheduler;
 import org.opensearch.sql.datasources.utils.XContentParserUtils;
 
@@ -98,6 +86,17 @@ public List<Route> routes() {
          */
         new Route(PUT, BASE_DATASOURCE_ACTION_URL),
 
+        /*
+         * PATCH datasources
+         * Request body:
+         * Ref
+         * [org.opensearch.sql.plugin.transport.datasource.model.PatchDataSourceActionRequest]
+         * Response body:
+         * Ref
+         * [org.opensearch.sql.plugin.transport.datasource.model.PatchDataSourceActionResponse]
+         */
+        new Route(PATCH, BASE_DATASOURCE_ACTION_URL),
+
         /*
          * DELETE datasources
          * Request body: Ref
@@ -122,6 +121,8 @@ protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient nodeClient)
         return executeUpdateRequest(restRequest, nodeClient);
       case DELETE:
         return executeDeleteRequest(restRequest, nodeClient);
+      case PATCH:
+        return executePatchRequest(restRequest, nodeClient);
       default:
         return restChannel ->
             restChannel.sendResponse(
@@ -216,6 +217,34 @@ public void onFailure(Exception e) {
                 }));
   }
 
+  private RestChannelConsumer executePatchRequest(RestRequest restRequest, NodeClient nodeClient)
+      throws IOException {
+    Map<String, Object> dataSourceData = XContentParserUtils.toMap(restRequest.contentParser());
+    return restChannel ->
+        Scheduler.schedule(
+            nodeClient,
+            () ->
+                nodeClient.execute(
+                    TransportPatchDataSourceAction.ACTION_TYPE,
+                    new PatchDataSourceActionRequest(dataSourceData),
+                    new ActionListener<>() {
+                      @Override
+                      public void onResponse(
+                          PatchDataSourceActionResponse patchDataSourceActionResponse) {
+                        restChannel.sendResponse(
+                            new BytesRestResponse(
+                                RestStatus.OK,
+                                "application/json; charset=UTF-8",
+                                patchDataSourceActionResponse.getResult()));
+                      }
+
+                      @Override
+                      public void onFailure(Exception e) {
+                        handleException(e, restChannel);
+                      }
+                    }));
+  }
+
   private RestChannelConsumer executeDeleteRequest(RestRequest restRequest, NodeClient nodeClient) {
     String dataSourceName = restRequest.param("dataSourceName");
 
diff --git a/datasources/src/main/java/org/opensearch/sql/datasources/service/DataSourceServiceImpl.java b/datasources/src/main/java/org/opensearch/sql/datasources/service/DataSourceServiceImpl.java
index 25e8006d66..8ba618fb44 100644
--- a/datasources/src/main/java/org/opensearch/sql/datasources/service/DataSourceServiceImpl.java
+++ b/datasources/src/main/java/org/opensearch/sql/datasources/service/DataSourceServiceImpl.java
@@ -6,15 +6,11 @@
 package org.opensearch.sql.datasources.service;
 
 import static org.opensearch.sql.analysis.DataSourceSchemaIdentifierNameResolver.DEFAULT_DATASOURCE_NAME;
+import static org.opensearch.sql.datasources.utils.XContentParserUtils.*;
 
 import com.google.common.base.Preconditions;
 import com.google.common.base.Strings;
-import java.util.HashMap;
-import java.util.HashSet;
-import java.util.List;
-import java.util.Objects;
-import java.util.Optional;
-import java.util.Set;
+import java.util.*;
 import org.opensearch.sql.common.utils.StringUtils;
 import org.opensearch.sql.datasource.DataSourceService;
 import org.opensearch.sql.datasource.model.DataSource;
@@ -100,6 +96,19 @@ public void updateDataSource(DataSourceMetadata dataSourceMetadata) {
     }
   }
 
+  @Override
+  public void patchDataSource(Map<String, Object> dataSourceData) {
+    if (!dataSourceData.get(NAME_FIELD).equals(DEFAULT_DATASOURCE_NAME)) {
+      DataSourceMetadata dataSourceMetadata =
+          getRawDataSourceMetadata((String) dataSourceData.get(NAME_FIELD));
+      replaceOldDatasourceMetadata(dataSourceData, dataSourceMetadata);
+      updateDataSource(dataSourceMetadata);
+    } else {
+      throw new UnsupportedOperationException(
+          "Not allowed to update default datasource :" + DEFAULT_DATASOURCE_NAME);
+    }
+  }
+
   @Override
   public void deleteDataSource(String dataSourceName) {
     if (dataSourceName.equals(DEFAULT_DATASOURCE_NAME)) {
@@ -136,6 +145,35 @@ private void validateDataSourceMetaData(DataSourceMetadata metadata) {
             + " Properties are required parameters.");
   }
 
+  /**
+   * Replaces the fields of the given metadata with the values from the map.
+   *
+   * @param dataSourceData map of field names to their new values.
+   * @param metadata {@link DataSourceMetadata}.
+   */
+  private void replaceOldDatasourceMetadata(
+      Map<String, Object> dataSourceData, DataSourceMetadata metadata) {
+
+    for (String key : dataSourceData.keySet()) {
+      switch (key) {
+          // Name and connector should not be modified
+        case DESCRIPTION_FIELD:
+          metadata.setDescription((String) dataSourceData.get(DESCRIPTION_FIELD));
+          break;
+        case ALLOWED_ROLES_FIELD:
+          metadata.setAllowedRoles((List<String>) dataSourceData.get(ALLOWED_ROLES_FIELD));
+          break;
+        case PROPERTIES_FIELD:
+          // Merge patched properties into the existing ones and apply the result.
+          Map<String, String> properties = new HashMap<>(metadata.getProperties());
+          properties.putAll(((Map<String, String>) dataSourceData.get(PROPERTIES_FIELD)));
+          metadata.setProperties(properties);
+          break;
+        case NAME_FIELD:
+        case CONNECTOR_FIELD:
+          break;
+      }
+    }
+  }
+
   @Override
   public DataSourceMetadata getRawDataSourceMetadata(String dataSourceName) {
     if (dataSourceName.equals(DEFAULT_DATASOURCE_NAME)) {
diff --git a/datasources/src/main/java/org/opensearch/sql/datasources/transport/TransportPatchDataSourceAction.java b/datasources/src/main/java/org/opensearch/sql/datasources/transport/TransportPatchDataSourceAction.java
new file mode 100644
index 0000000000..303e905cec
--- /dev/null
+++ b/datasources/src/main/java/org/opensearch/sql/datasources/transport/TransportPatchDataSourceAction.java
@@ -0,0 +1,74 @@
+/*
+ *
+ * * Copyright OpenSearch Contributors
+ * * SPDX-License-Identifier: Apache-2.0
+ *
+ */
+
+package org.opensearch.sql.datasources.transport;
+
+import static org.opensearch.sql.datasources.utils.XContentParserUtils.NAME_FIELD;
+import static org.opensearch.sql.protocol.response.format.JsonResponseFormatter.Style.PRETTY;
+
+import org.opensearch.action.ActionType;
+import org.opensearch.action.support.ActionFilters;
+import org.opensearch.action.support.HandledTransportAction;
+import org.opensearch.common.inject.Inject;
+import org.opensearch.core.action.ActionListener;
+import org.opensearch.sql.datasource.DataSourceService;
+import org.opensearch.sql.datasources.model.transport.PatchDataSourceActionRequest;
+import org.opensearch.sql.datasources.model.transport.PatchDataSourceActionResponse;
+import org.opensearch.sql.datasources.service.DataSourceServiceImpl;
+import org.opensearch.sql.protocol.response.format.JsonResponseFormatter;
+import org.opensearch.tasks.Task;
+import org.opensearch.transport.TransportService;
+
+public class TransportPatchDataSourceAction
+    extends HandledTransportAction<PatchDataSourceActionRequest, PatchDataSourceActionResponse> {
+
+  public static final String NAME = "cluster:admin/opensearch/ql/datasources/patch";
+  public static final ActionType<PatchDataSourceActionResponse> ACTION_TYPE =
+      new ActionType<>(NAME, PatchDataSourceActionResponse::new);
+
+  private DataSourceService dataSourceService;
+
+  /**
+   * TransportPatchDataSourceAction action for patching a datasource.
+   *
+   * @param transportService transportService.
+   * @param actionFilters actionFilters.
+   * @param dataSourceService dataSourceService.
+   */
+  @Inject
+  public TransportPatchDataSourceAction(
+      TransportService transportService,
+      ActionFilters actionFilters,
+      DataSourceServiceImpl dataSourceService) {
+    super(
+        TransportPatchDataSourceAction.NAME,
+        transportService,
+        actionFilters,
+        PatchDataSourceActionRequest::new);
+    this.dataSourceService = dataSourceService;
+  }
+
+  @Override
+  protected void doExecute(
+      Task task,
+      PatchDataSourceActionRequest request,
+      ActionListener<PatchDataSourceActionResponse> actionListener) {
+    try {
+      dataSourceService.patchDataSource(request.getDataSourceData());
+      String responseContent =
+          new JsonResponseFormatter<String>(PRETTY) {
+            @Override
+            protected Object buildJsonObject(String response) {
+              return response;
+            }
+          }.format("Updated DataSource with name " + request.getDataSourceData().get(NAME_FIELD));
+      actionListener.onResponse(new PatchDataSourceActionResponse(responseContent));
+    } catch (Exception e) {
+      actionListener.onFailure(e);
+    }
+  }
+}
diff --git a/datasources/src/main/java/org/opensearch/sql/datasources/utils/XContentParserUtils.java b/datasources/src/main/java/org/opensearch/sql/datasources/utils/XContentParserUtils.java
index 261f13870a..6af2a5a761 100644
--- a/datasources/src/main/java/org/opensearch/sql/datasources/utils/XContentParserUtils.java
+++ b/datasources/src/main/java/org/opensearch/sql/datasources/utils/XContentParserUtils.java
@@ -90,6 +90,59 @@ public static DataSourceMetadata toDataSourceMetadata(XContentParser parser) throws
         name, description, connector, allowedRoles, properties, resultIndex);
   }
 
+  public static Map<String, Object> toMap(XContentParser parser) throws IOException {
+    Map<String, Object> resultMap = new HashMap<>();
+    String name;
+    String description;
+    List<String> allowedRoles = new ArrayList<>();
+    Map<String, String> properties = new HashMap<>();
+    String resultIndex;
+    ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
+    while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
+      String fieldName = parser.currentName();
+      parser.nextToken();
+      switch (fieldName) {
+        case NAME_FIELD:
+          name = parser.textOrNull();
+          resultMap.put(NAME_FIELD, name);
+          break;
+        case DESCRIPTION_FIELD:
+          description = parser.textOrNull();
+          resultMap.put(DESCRIPTION_FIELD, description);
+          break;
+        case CONNECTOR_FIELD:
+          // no-op - datasource connector should not be modified
+          break;
+        case ALLOWED_ROLES_FIELD:
+          while (parser.nextToken() != XContentParser.Token.END_ARRAY) {
+            allowedRoles.add(parser.text());
+          }
+          resultMap.put(ALLOWED_ROLES_FIELD, allowedRoles);
+          break;
+        case PROPERTIES_FIELD:
+          ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
+          while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
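+            // Each property is parsed as a plain string key/value pair.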
+            String key = parser.currentName();
+            parser.nextToken();
+            String value = parser.textOrNull();
+            properties.put(key, value);
+          }
+          resultMap.put(PROPERTIES_FIELD, properties);
+          break;
+        case RESULT_INDEX_FIELD:
+          resultIndex = parser.textOrNull();
+          resultMap.put(RESULT_INDEX_FIELD, resultIndex);
+          break;
+        default:
+          throw new IllegalArgumentException("Unknown field: " + fieldName);
+      }
+    }
+    if (resultMap.get(NAME_FIELD) == null || "".equals(resultMap.get(NAME_FIELD))) {
+      throw new IllegalArgumentException("Name is a required field.");
+    }
+    return resultMap;
+  }
+
   /**
    * Converts json string to DataSourceMetadata.
    *
@@ -109,6 +162,25 @@ public static DataSourceMetadata toDataSourceMetadata(String json) throws IOException {
     }
   }
 
+  /**
+   * Converts json string to Map.
+   *
+   * @param json json string.
+   * @return datasource fields as a Map.
+   * @throws IOException IOException.
+   */
+  public static Map<String, Object> toMap(String json) throws IOException {
+    try (XContentParser parser =
+        XContentType.JSON
+            .xContent()
+            .createParser(
+                NamedXContentRegistry.EMPTY,
+                DeprecationHandler.THROW_UNSUPPORTED_OPERATION,
+                json)) {
+      return toMap(parser);
+    }
+  }
+
   /**
    * Converts DataSourceMetadata to XContentBuilder.
    *
diff --git a/datasources/src/test/java/org/opensearch/sql/datasources/service/DataSourceServiceImplTest.java b/datasources/src/test/java/org/opensearch/sql/datasources/service/DataSourceServiceImplTest.java
index 6164d8b73f..c62e586dae 100644
--- a/datasources/src/test/java/org/opensearch/sql/datasources/service/DataSourceServiceImplTest.java
+++ b/datasources/src/test/java/org/opensearch/sql/datasources/service/DataSourceServiceImplTest.java
@@ -18,6 +18,7 @@
 import static org.mockito.Mockito.verifyNoInteractions;
 import static org.mockito.Mockito.when;
 import static org.opensearch.sql.analysis.DataSourceSchemaIdentifierNameResolver.DEFAULT_DATASOURCE_NAME;
+import static org.opensearch.sql.datasources.utils.XContentParserUtils.*;
 
 import com.google.common.collect.ImmutableMap;
 import java.util.ArrayList;
@@ -264,7 +265,6 @@ void testGetDataSourceMetadataSetWithDefaultDatasource() {
 
   @Test
   void testUpdateDataSourceSuccessCase() {
-
     DataSourceMetadata dataSourceMetadata =
         metadata("testDS", DataSourceType.OPENSEARCH, Collections.emptyList(), ImmutableMap.of());
     dataSourceService.updateDataSource(dataSourceMetadata);
@@ -289,6 +289,46 @@ void testUpdateDefaultDataSource() {
         unsupportedOperationException.getMessage());
   }
 
+  @Test
+  void testPatchDefaultDataSource() {
+    Map<String, Object> dataSourceData =
+        Map.of(NAME_FIELD, DEFAULT_DATASOURCE_NAME, DESCRIPTION_FIELD, "test");
+    UnsupportedOperationException unsupportedOperationException =
+        assertThrows(
+            UnsupportedOperationException.class,
+            () -> dataSourceService.patchDataSource(dataSourceData));
+    assertEquals(
+        "Not allowed to update default datasource :" + DEFAULT_DATASOURCE_NAME,
+        unsupportedOperationException.getMessage());
+  }
+
+  @Test
+  void testPatchDataSourceSuccessCase() {
+    // Tests that the underlying implementation of patch is to call update
+    Map<String, Object> dataSourceData =
+        new HashMap<>(
+            Map.of(
+                NAME_FIELD,
+                "testDS",
+                DESCRIPTION_FIELD,
+                "test",
+                CONNECTOR_FIELD,
+                "PROMETHEUS",
+                ALLOWED_ROLES_FIELD,
+                new ArrayList<>(),
+                PROPERTIES_FIELD,
+                Map.of(),
+                RESULT_INDEX_FIELD,
+                ""));
+    DataSourceMetadata getData =
+        metadata("testDS", DataSourceType.OPENSEARCH, Collections.emptyList(), ImmutableMap.of());
+    when(dataSourceMetadataStorage.getDataSourceMetadata("testDS"))
+        .thenReturn(Optional.ofNullable(getData));
+
+    dataSourceService.patchDataSource(dataSourceData);
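+    // A successful patch should surface as exactly one storage-level update.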
+    verify(dataSourceMetadataStorage, times(1)).updateDataSourceMetadata(any());
+  }
+
   @Test
   void testDeleteDatasource() {
     dataSourceService.deleteDataSource("testDS");
diff --git a/datasources/src/test/java/org/opensearch/sql/datasources/transport/TransportPatchDataSourceActionTest.java b/datasources/src/test/java/org/opensearch/sql/datasources/transport/TransportPatchDataSourceActionTest.java
new file mode 100644
index 0000000000..5e1e7df112
--- /dev/null
+++ b/datasources/src/test/java/org/opensearch/sql/datasources/transport/TransportPatchDataSourceActionTest.java
@@ -0,0 +1,79 @@
+package org.opensearch.sql.datasources.transport;
+
+import static org.mockito.Mockito.*;
+import static org.opensearch.sql.datasources.utils.XContentParserUtils.*;
+
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.ExtendWith;
+import org.mockito.ArgumentCaptor;
+import org.mockito.Captor;
+import org.mockito.Mock;
+import org.mockito.Mockito;
+import org.mockito.junit.jupiter.MockitoExtension;
+import org.opensearch.action.support.ActionFilters;
+import org.opensearch.core.action.ActionListener;
+import org.opensearch.sql.datasources.model.transport.PatchDataSourceActionRequest;
+import org.opensearch.sql.datasources.model.transport.PatchDataSourceActionResponse;
+import org.opensearch.sql.datasources.service.DataSourceServiceImpl;
+import org.opensearch.tasks.Task;
+import org.opensearch.transport.TransportService;
+
+@ExtendWith(MockitoExtension.class)
+public class TransportPatchDataSourceActionTest {
+
+  @Mock private TransportService transportService;
+  @Mock private TransportPatchDataSourceAction action;
+  @Mock private DataSourceServiceImpl dataSourceService;
+  @Mock private Task task;
+  @Mock private ActionListener<PatchDataSourceActionResponse> actionListener;
+
+  @Captor
+  private ArgumentCaptor<PatchDataSourceActionResponse>
+      patchDataSourceActionResponseArgumentCaptor;
+
+  @Captor private ArgumentCaptor<Exception> exceptionArgumentCaptor;
+
+  @BeforeEach
+  public void setUp() {
+    action =
+        new TransportPatchDataSourceAction(
+            transportService, new ActionFilters(new HashSet<>()), dataSourceService);
+  }
+
+  @Test
+  public void testDoExecute() {
+    Map<String, Object> dataSourceData = new HashMap<>();
+    dataSourceData.put(NAME_FIELD, "test_datasource");
+    dataSourceData.put(DESCRIPTION_FIELD, "test");
+
+    PatchDataSourceActionRequest request = new PatchDataSourceActionRequest(dataSourceData);
+
+    action.doExecute(task, request, actionListener);
+    verify(dataSourceService, times(1)).patchDataSource(dataSourceData);
+    Mockito.verify(actionListener)
+        .onResponse(patchDataSourceActionResponseArgumentCaptor.capture());
+    PatchDataSourceActionResponse patchDataSourceActionResponse =
+        patchDataSourceActionResponseArgumentCaptor.getValue();
+    String responseAsJson = "\"Updated DataSource with name test_datasource\"";
+    Assertions.assertEquals(responseAsJson, patchDataSourceActionResponse.getResult());
+  }
+
+  @Test
+  public void testDoExecuteWithException() {
+    Map<String, Object> dataSourceData = new HashMap<>();
+    dataSourceData.put(NAME_FIELD, "test_datasource");
+    dataSourceData.put(DESCRIPTION_FIELD, "test");
+    doThrow(new RuntimeException("Error")).when(dataSourceService).patchDataSource(dataSourceData);
+    PatchDataSourceActionRequest request = new PatchDataSourceActionRequest(dataSourceData);
+    action.doExecute(task, request, actionListener);
+    verify(dataSourceService, times(1)).patchDataSource(dataSourceData);
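+    // The service-layer exception should propagate to the listener's onFailure callback.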
+    Mockito.verify(actionListener).onFailure(exceptionArgumentCaptor.capture());
+    Exception exception = exceptionArgumentCaptor.getValue();
+    Assertions.assertTrue(exception instanceof RuntimeException);
+    Assertions.assertEquals("Error", exception.getMessage());
+  }
+}
diff --git a/datasources/src/test/java/org/opensearch/sql/datasources/utils/XContentParserUtilsTest.java b/datasources/src/test/java/org/opensearch/sql/datasources/utils/XContentParserUtilsTest.java
index d134293456..e1e442d12b 100644
--- a/datasources/src/test/java/org/opensearch/sql/datasources/utils/XContentParserUtilsTest.java
+++ b/datasources/src/test/java/org/opensearch/sql/datasources/utils/XContentParserUtilsTest.java
@@ -1,6 +1,7 @@
 package org.opensearch.sql.datasources.utils;
 
 import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.opensearch.sql.datasources.utils.XContentParserUtils.*;
 
 import com.google.gson.Gson;
 import java.util.HashMap;
@@ -52,6 +53,46 @@ public void testToDataSourceMetadataFromJson() {
     Assertions.assertEquals("prometheus_access", retrievedMetadata.getAllowedRoles().get(0));
   }
 
+  @SneakyThrows
+  @Test
+  public void testToMapFromJson() {
+    Map<String, Object> dataSourceData =
+        Map.of(
+            NAME_FIELD,
+            "test_DS",
+            DESCRIPTION_FIELD,
+            "test",
+            ALLOWED_ROLES_FIELD,
+            List.of("all_access"),
+            PROPERTIES_FIELD,
+            Map.of("prometheus.uri", "localhost:9090"),
+            CONNECTOR_FIELD,
+            "PROMETHEUS",
+            RESULT_INDEX_FIELD,
+            "");
+
+    Map<String, Object> dataSourceDataConnectorRemoved =
+        Map.of(
+            NAME_FIELD,
+            "test_DS",
+            DESCRIPTION_FIELD,
+            "test",
+            ALLOWED_ROLES_FIELD,
+            List.of("all_access"),
+            PROPERTIES_FIELD,
+            Map.of("prometheus.uri", "localhost:9090"),
+            RESULT_INDEX_FIELD,
+            "");
+
+    Gson gson = new Gson();
+    String json = gson.toJson(dataSourceData);
+
+    Map<String, Object> parsedData = XContentParserUtils.toMap(json);
+
+    Assertions.assertEquals(parsedData, dataSourceDataConnectorRemoved);
+    Assertions.assertEquals("test", parsedData.get(DESCRIPTION_FIELD));
+  }
+
   @SneakyThrows
   @Test
   public void testToDataSourceMetadataFromJsonWithoutName() {
@@ -71,6 +112,22 @@ public void testToDataSourceMetadataFromJsonWithoutName() {
     Assertions.assertEquals("name and connector are required fields.", exception.getMessage());
   }
 
+  @SneakyThrows
+  @Test
+  public void testToMapFromJsonWithoutName() {
+    Map<String, Object> dataSourceData = new HashMap<>(Map.of(DESCRIPTION_FIELD, "test"));
+    Gson gson = new Gson();
+    String json = gson.toJson(dataSourceData);
+
+    IllegalArgumentException exception =
+        assertThrows(
+            IllegalArgumentException.class,
+            () -> {
+              XContentParserUtils.toMap(json);
+            });
+    Assertions.assertEquals("Name is a required field.", exception.getMessage());
+  }
+
   @SneakyThrows
   @Test
   public void testToDataSourceMetadataFromJsonWithoutConnector() {
@@ -106,4 +163,21 @@ public void testToDataSourceMetadataFromJsonUsingUnknownObject() {
             });
     Assertions.assertEquals("Unknown field: test", exception.getMessage());
   }
+
+  @SneakyThrows
+  @Test
+  public void testToMapFromJsonUsingUnknownObject() {
+    HashMap<String, String> hashMap = new HashMap<>();
+    hashMap.put("test", "test");
+    Gson gson = new Gson();
+    String json = gson.toJson(hashMap);
+
+    IllegalArgumentException exception =
+        assertThrows(
+            IllegalArgumentException.class,
+            () -> {
+              XContentParserUtils.toMap(json);
+            });
+    Assertions.assertEquals("Unknown field: test", exception.getMessage());
+  }
 }
diff --git a/docs/user/admin/settings.rst b/docs/user/admin/settings.rst
index b5da4e28e2..686116636a 100644
--- a/docs/user/admin/settings.rst
+++ b/docs/user/admin/settings.rst
@@ -311,3 +311,75 @@ SQL query::
       "status": 400
     }
 
+plugins.query.executionengine.spark.session.enabled
+===================================================
+
+Description
+-----------
+
+By default, the execution engine runs queries in job mode. You can enable session mode with this setting.
+
+1. The default value is false.
+2. This setting is node scope.
+3. This setting can be updated dynamically.
+
+You can update the setting with a new value like this.
+
+SQL query::
+
+    sh$ curl -sS -H 'Content-Type: application/json' -X PUT localhost:9200/_plugins/_query/settings \
+    ... -d '{"transient":{"plugins.query.executionengine.spark.session.enabled":"true"}}'
+    {
+      "acknowledged": true,
+      "persistent": {},
+      "transient": {
+        "plugins": {
+          "query": {
+            "executionengine": {
+              "spark": {
+                "session": {
+                  "enabled": "true"
+                }
+              }
+            }
+          }
+        }
+      }
+    }
+
+plugins.query.executionengine.spark.session.limit
+=================================================
+
+Description
+-----------
+
+By default, each datasource can have at most 100 sessions running in parallel. You can increase the limit with this setting.
+
+1. The default value is 100.
+2. This setting is node scope.
+3. This setting can be updated dynamically.
+
+You can update the setting with a new value like this.
+
+SQL query::
+
+    sh$ curl -sS -H 'Content-Type: application/json' -X PUT localhost:9200/_plugins/_query/settings \
+    ... -d '{"transient":{"plugins.query.executionengine.spark.session.limit":200}}'
+    {
+      "acknowledged": true,
+      "persistent": {},
+      "transient": {
+        "plugins": {
+          "query": {
+            "executionengine": {
+              "spark": {
+                "session": {
+                  "limit": "200"
+                }
+              }
+            }
+          }
+        }
+      }
+    }
+
diff --git a/docs/user/interfaces/asyncqueryinterface.rst b/docs/user/interfaces/asyncqueryinterface.rst
index a9fc77264c..3fbc16d15f 100644
--- a/docs/user/interfaces/asyncqueryinterface.rst
+++ b/docs/user/interfaces/asyncqueryinterface.rst
@@ -62,6 +62,50 @@ Sample Response::
     "queryId": "00fd796ut1a7eg0q"
   }
 
+Execute query in session
+------------------------
+
+If ``plugins.query.executionengine.spark.session.enabled`` is set to true, session-based execution is enabled. Under the hood, all queries submitted to the same session are executed in the same SparkContext. A session is automatically closed if no query is submitted for 10 minutes.
+
+The async query response includes a ``sessionId`` field, indicating that the query was executed in a session.
+
+Sample Request::
+
+    curl --location 'http://localhost:9200/_plugins/_async_query' \
+    --header 'Content-Type: application/json' \
+    --data '{
+        "datasource" : "my_glue",
+        "lang" : "sql",
+        "query" : "select * from my_glue.default.http_logs limit 10"
+    }'
+
+Sample Response::
+
+    {
+      "queryId": "HlbM61kX6MDkAktO",
+      "sessionId": "1Giy65ZnzNlmsPAm"
+    }
+
+Users can reuse the session by passing the ``sessionId`` in the request body.
+
+Sample Request::
+
+    curl --location 'http://localhost:9200/_plugins/_async_query' \
+    --header 'Content-Type: application/json' \
+    --data '{
+        "datasource" : "my_glue",
+        "lang" : "sql",
+        "query" : "select * from my_glue.default.http_logs limit 10",
+        "sessionId" : "1Giy65ZnzNlmsPAm"
+    }'
+
+Sample Response::
+
+    {
+      "queryId": "7GC4mHhftiTejvxN",
+      "sessionId": "1Giy65ZnzNlmsPAm"
+    }
+
 Async Query Result API
 ======================================
diff --git a/docs/user/ppl/admin/datasources.rst b/docs/user/ppl/admin/datasources.rst
index 3682153b9d..31378f6cc4 100644
--- a/docs/user/ppl/admin/datasources.rst
+++ b/docs/user/ppl/admin/datasources.rst
@@ -93,6 +93,19 @@ we can remove authorization and other details in case of security disabled domains.
       "allowedRoles" : ["prometheus_access"]
   }
 
+* Datasource modification PATCH API ("_plugins/_query/_datasources") ::
+
+    PATCH https://localhost:9200/_plugins/_query/_datasources
+    content-type: application/json
+    Authorization: Basic {{username}} {{password}}
+
+    {
+      "name" : "my_prometheus",
+      "allowedRoles" : ["all_access"]
+    }
+
+  **Name is required and must exist. Connector cannot be modified and will be ignored.**
+
 * Datasource Read GET API("_plugins/_query/_datasources/{{dataSourceName}}" ::
 
     GET https://localhost:9200/_plugins/_query/_datasources/my_prometheus
@@ -114,6 +127,7 @@ Each of the datasource configuration management apis are controlled by following actions:
 
 * cluster:admin/opensearch/datasources/create [Create POST API]
 * cluster:admin/opensearch/datasources/read [Get GET API]
 * cluster:admin/opensearch/datasources/update [Update PUT API]
+* cluster:admin/opensearch/datasources/patch [Update PATCH API]
 * cluster:admin/opensearch/datasources/delete [Delete DELETE API]
 
 Only users mapped with roles having above actions are authorized to execute datasource management apis.
diff --git a/integ-test/build.gradle b/integ-test/build.gradle
index 6925cb9101..85535da68c 100644
--- a/integ-test/build.gradle
+++ b/integ-test/build.gradle
@@ -42,7 +42,7 @@ apply plugin: 'java'
 apply plugin: 'io.freefair.lombok'
 apply plugin: 'com.wiredforcode.spawn'
 
-String baseVersion = "2.11.0"
+String baseVersion = "2.12.0"
 String bwcVersion = baseVersion + ".0";
 String baseName = "sqlBwcCluster"
 String bwcFilePath = "src/test/resources/bwc/"
@@ -175,7 +175,7 @@ dependencies {
   testImplementation group: 'org.opensearch.client', name: 'opensearch-rest-client', version: "${opensearch_version}"
   testImplementation group: 'org.opensearch.driver', name: 'opensearch-sql-jdbc', version: System.getProperty("jdbcDriverVersion", '1.2.0.0')
   testImplementation group: 'org.hamcrest', name: 'hamcrest', version: '2.1'
-  implementation group: 'org.apache.logging.log4j', name: 'log4j-core', version:'2.20.0'
+  implementation group: 'org.apache.logging.log4j', name: 'log4j-core', version:"${versions.log4j}"
   testImplementation project(':opensearch-sql-plugin')
   testImplementation project(':legacy')
   testImplementation('org.junit.jupiter:junit-jupiter-api:5.6.2')
diff --git a/integ-test/src/test/java/org/opensearch/sql/datasource/DataSourceAPIsIT.java b/integ-test/src/test/java/org/opensearch/sql/datasource/DataSourceAPIsIT.java
index 8623b9fa6f..ff36d2a887 100644
--- a/integ-test/src/test/java/org/opensearch/sql/datasource/DataSourceAPIsIT.java
+++ b/integ-test/src/test/java/org/opensearch/sql/datasource/DataSourceAPIsIT.java
@@ -5,6 +5,8 @@
 
 package org.opensearch.sql.datasource;
 
+import static org.opensearch.sql.datasources.utils.XContentParserUtils.DESCRIPTION_FIELD;
+import static org.opensearch.sql.datasources.utils.XContentParserUtils.NAME_FIELD;
 import static org.opensearch.sql.legacy.TestUtils.getResponseBody;
 
 import com.google.common.collect.ImmutableList;
@@ -15,7 +17,9 @@
 import java.io.IOException;
 import java.lang.reflect.Type;
 import java.util.ArrayList;
+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
 import lombok.SneakyThrows;
 import org.junit.AfterClass;
 import org.junit.Assert;
@@ -136,6 +140,31 @@ public void updateDataSourceAPITest() {
     Assert.assertEquals(
         "https://randomtest.com:9090", dataSourceMetadata.getProperties().get("prometheus.uri"));
     Assert.assertEquals("", dataSourceMetadata.getDescription());
+
+    // patch datasource
+    Map<String, Object> updateDS =
+        new HashMap<>(Map.of(NAME_FIELD, "update_prometheus", DESCRIPTION_FIELD, "test"));
+    Request patchRequest = getPatchDataSourceRequest(updateDS);
+    Response patchResponse = client().performRequest(patchRequest);
+    Assert.assertEquals(200, patchResponse.getStatusLine().getStatusCode());
+    String patchResponseString = getResponseBody(patchResponse);
+    Assert.assertEquals("\"Updated DataSource with name update_prometheus\"", patchResponseString);
+
+    // Datasource is not immediately updated, so introduce a sleep of 2s.
+    Thread.sleep(2000);
+
+    // get datasource to validate the modification
+    Request getRequestAfterPatch = getFetchDataSourceRequest("update_prometheus");
+    Response getResponseAfterPatch = client().performRequest(getRequestAfterPatch);
+    Assert.assertEquals(200, getResponseAfterPatch.getStatusLine().getStatusCode());
+    String getResponseStringAfterPatch = getResponseBody(getResponseAfterPatch);
+    DataSourceMetadata dataSourceMetadataAfterPatch =
+        new Gson().fromJson(getResponseStringAfterPatch, DataSourceMetadata.class);
+    Assert.assertEquals(
+        "https://randomtest.com:9090",
+        dataSourceMetadataAfterPatch.getProperties().get("prometheus.uri"));
+    Assert.assertEquals("test", dataSourceMetadataAfterPatch.getDescription());
   }
 
   @SneakyThrows
diff --git a/integ-test/src/test/java/org/opensearch/sql/legacy/SQLIntegTestCase.java b/integ-test/src/test/java/org/opensearch/sql/legacy/SQLIntegTestCase.java
index 8335ada5a7..058182f123 100644
--- a/integ-test/src/test/java/org/opensearch/sql/legacy/SQLIntegTestCase.java
+++ b/integ-test/src/test/java/org/opensearch/sql/legacy/SQLIntegTestCase.java
@@ -49,6 +49,7 @@
 import java.nio.file.Path;
 import java.nio.file.Paths;
 import java.util.Locale;
+import java.util.Map;
 import javax.management.MBeanServerInvocationHandler;
 import javax.management.ObjectName;
 import javax.management.remote.JMXConnector;
@@ -488,6 +489,15 @@ protected static Request getUpdateDataSourceRequest(DataSourceMetadata dataSourceMetadata) {
     return request;
   }
 
+  protected static Request getPatchDataSourceRequest(Map<String, Object> dataSourceData) {
+    Request request = new Request("PATCH", "/_plugins/_query/_datasources");
+    request.setJsonEntity(new Gson().toJson(dataSourceData));
+    RequestOptions.Builder restOptionsBuilder = RequestOptions.DEFAULT.toBuilder();
+    restOptionsBuilder.addHeader("Content-Type", "application/json");
+    request.setOptions(restOptionsBuilder);
+    return request;
+  }
+
   protected static Request getFetchDataSourceRequest(String name) {
     Request request = new Request("GET", "/_plugins/_query/_datasources" + "/" + name);
     if (StringUtils.isEmpty(name)) {
diff --git a/legacy/build.gradle b/legacy/build.gradle
index fc985989e5..ca20476610 100644
--- a/legacy/build.gradle
+++ b/legacy/build.gradle
@@ -108,7 +108,7 @@ dependencies {
         }
     }
     implementation group: 'com.google.guava', name: 'guava', version: '32.0.1-jre'
-    implementation group: 'org.json', name: 'json', version:'20230227'
+    implementation group: 'org.json', name: 'json', version:'20231013'
    implementation group: 'org.apache.commons', name: 'commons-lang3', version: '3.12.0'
    implementation group: 'org.opensearch', name: 'opensearch', version: "${opensearch_version}"
    // add geo module as dependency. https://github.com/opensearch-project/OpenSearch/pull/4180/.
diff --git a/opensearch/build.gradle b/opensearch/build.gradle
index 34b5c3f452..c9087bca49 100644
--- a/opensearch/build.gradle
+++ b/opensearch/build.gradle
@@ -37,7 +37,7 @@ dependencies {
     implementation group: 'com.fasterxml.jackson.core', name: 'jackson-core', version: "${versions.jackson}"
     implementation group: 'com.fasterxml.jackson.core', name: 'jackson-databind', version: "${versions.jackson_databind}"
     implementation group: 'com.fasterxml.jackson.dataformat', name: 'jackson-dataformat-cbor', version: "${versions.jackson}"
-    implementation group: 'org.json', name: 'json', version:'20230227'
+    implementation group: 'org.json', name: 'json', version:'20231013'
     compileOnly group: 'org.opensearch.client', name: 'opensearch-rest-high-level-client', version: "${opensearch_version}"
     implementation group: 'org.opensearch', name:'opensearch-ml-client', version: "${opensearch_build}"
 
diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/setting/OpenSearchSettings.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/setting/OpenSearchSettings.java
index 76bda07607..f80b576fe0 100644
--- a/opensearch/src/main/java/org/opensearch/sql/opensearch/setting/OpenSearchSettings.java
+++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/setting/OpenSearchSettings.java
@@ -135,6 +135,20 @@ public class OpenSearchSettings extends Settings {
           Setting.Property.NodeScope,
           Setting.Property.Dynamic);
 
+  public static final Setting<Boolean> SPARK_EXECUTION_SESSION_ENABLED_SETTING =
+      Setting.boolSetting(
+          Key.SPARK_EXECUTION_SESSION_ENABLED.getKeyValue(),
+          false,
+          Setting.Property.NodeScope,
+          Setting.Property.Dynamic);
+
+  public static final Setting<Integer> SPARK_EXECUTION_SESSION_LIMIT_SETTING =
+      Setting.intSetting(
+          Key.SPARK_EXECUTION_SESSION_LIMIT.getKeyValue(),
+          100,
+          Setting.Property.NodeScope,
+          Setting.Property.Dynamic);
+
   /** Construct OpenSearchSetting. The OpenSearchSetting must be singleton. */
   @SuppressWarnings("unchecked")
   public OpenSearchSettings(ClusterSettings clusterSettings) {
@@ -205,6 +219,18 @@ public OpenSearchSettings(ClusterSettings clusterSettings) {
         Key.SPARK_EXECUTION_ENGINE_CONFIG,
         SPARK_EXECUTION_ENGINE_CONFIG,
         new Updater(Key.SPARK_EXECUTION_ENGINE_CONFIG));
+    register(
+        settingBuilder,
+        clusterSettings,
+        Key.SPARK_EXECUTION_SESSION_ENABLED,
+        SPARK_EXECUTION_SESSION_ENABLED_SETTING,
+        new Updater(Key.SPARK_EXECUTION_SESSION_ENABLED));
+    register(
+        settingBuilder,
+        clusterSettings,
+        Key.SPARK_EXECUTION_SESSION_LIMIT,
+        SPARK_EXECUTION_SESSION_LIMIT_SETTING,
+        new Updater(Key.SPARK_EXECUTION_SESSION_LIMIT));
     registerNonDynamicSettings(
         settingBuilder, clusterSettings, Key.CLUSTER_NAME, ClusterName.CLUSTER_NAME_SETTING);
     defaultSettings = settingBuilder.build();
@@ -270,6 +296,8 @@ public static List<Setting<?>> pluginSettings() {
         .add(METRICS_ROLLING_INTERVAL_SETTING)
         .add(DATASOURCE_URI_HOSTS_DENY_LIST)
         .add(SPARK_EXECUTION_ENGINE_CONFIG)
+        .add(SPARK_EXECUTION_SESSION_ENABLED_SETTING)
+        .add(SPARK_EXECUTION_SESSION_LIMIT_SETTING)
         .build();
   }
 
diff --git a/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java b/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java
index f3fd043b63..3d9740d84c 100644
--- a/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java
+++ b/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java
@@ -58,18 +58,12 @@
 import org.opensearch.sql.datasources.auth.DataSourceUserAuthorizationHelperImpl;
 import org.opensearch.sql.datasources.encryptor.EncryptorImpl;
 import org.opensearch.sql.datasources.glue.GlueDataSourceFactory;
-import org.opensearch.sql.datasources.model.transport.CreateDataSourceActionResponse;
-import org.opensearch.sql.datasources.model.transport.DeleteDataSourceActionResponse;
-import org.opensearch.sql.datasources.model.transport.GetDataSourceActionResponse;
-import org.opensearch.sql.datasources.model.transport.UpdateDataSourceActionResponse;
+import org.opensearch.sql.datasources.model.transport.*;
 import org.opensearch.sql.datasources.rest.RestDataSourceQueryAction;
 import org.opensearch.sql.datasources.service.DataSourceMetadataStorage;
 import org.opensearch.sql.datasources.service.DataSourceServiceImpl;
 import org.opensearch.sql.datasources.storage.OpenSearchDataSourceMetadataStorage;
-import org.opensearch.sql.datasources.transport.TransportCreateDataSourceAction;
-import org.opensearch.sql.datasources.transport.TransportDeleteDataSourceAction;
-import org.opensearch.sql.datasources.transport.TransportGetDataSourceAction;
-import org.opensearch.sql.datasources.transport.TransportUpdateDataSourceAction;
+import org.opensearch.sql.datasources.transport.*;
 import org.opensearch.sql.legacy.esdomain.LocalClusterState;
 import org.opensearch.sql.legacy.executor.AsyncRestExecutor;
 import org.opensearch.sql.legacy.metrics.Metrics;
@@ -99,6 +93,8 @@
 import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplier;
 import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplierImpl;
 import org.opensearch.sql.spark.dispatcher.SparkQueryDispatcher;
+import org.opensearch.sql.spark.execution.session.SessionManager;
+import org.opensearch.sql.spark.execution.statestore.StateStore;
 import org.opensearch.sql.spark.flint.FlintIndexMetadataReaderImpl;
 import org.opensearch.sql.spark.response.JobExecutionResponseReader;
 import org.opensearch.sql.spark.rest.RestAsyncQueryManagementAction;
@@ -180,6 +176,10 @@ public List<RestHandler> getRestHandlers(
             new ActionType<>(
                 TransportUpdateDataSourceAction.NAME, UpdateDataSourceActionResponse::new),
             TransportUpdateDataSourceAction.class),
+        new ActionHandler<>(
+            new ActionType<>(
+                TransportPatchDataSourceAction.NAME, PatchDataSourceActionResponse::new),
+            TransportPatchDataSourceAction.class),
         new ActionHandler<>(
             new ActionType<>(
                 TransportDeleteDataSourceAction.NAME, DeleteDataSourceActionResponse::new),
@@ -306,8 +306,9 @@ private DataSourceServiceImpl createDataSourceService() {
   private AsyncQueryExecutorService createAsyncQueryExecutorService(
       SparkExecutionEngineConfigSupplier sparkExecutionEngineConfigSupplier,
       SparkExecutionEngineConfig sparkExecutionEngineConfig) {
+    StateStore stateStore = new StateStore(client, clusterService);
     AsyncQueryJobMetadataStorageService asyncQueryJobMetadataStorageService =
-        new OpensearchAsyncQueryJobMetadataStorageService(client, clusterService);
+        new OpensearchAsyncQueryJobMetadataStorageService(stateStore);
     EMRServerlessClient emrServerlessClient =
         createEMRServerlessClient(sparkExecutionEngineConfig.getRegion());
     JobExecutionResponseReader jobExecutionResponseReader = new JobExecutionResponseReader(client);
@@ -318,7 +319,8 @@ private AsyncQueryExecutorService createAsyncQueryExecutorService(
             new DataSourceUserAuthorizationHelperImpl(client),
             jobExecutionResponseReader,
             new FlintIndexMetadataReaderImpl(client),
-            client);
+            client,
+            new SessionManager(stateStore, emrServerlessClient, pluginSettings));
     return new AsyncQueryExecutorServiceImpl(
         asyncQueryJobMetadataStorageService,
         sparkQueryDispatcher,
diff --git a/ppl/build.gradle b/ppl/build.gradle
index a798b3f4b0..75281e9160 100644
--- a/ppl/build.gradle
+++ b/ppl/build.gradle
@@ -49,8 +49,8 @@ dependencies {
     implementation "org.antlr:antlr4-runtime:4.7.1"
     implementation group: 'com.google.guava', name: 'guava', version: '32.0.1-jre'
-    api group: 'org.json', name: 'json', version: '20230227'
-    implementation group: 'org.apache.logging.log4j', name: 'log4j-core', version:'2.20.0'
+    api group: 'org.json', name: 'json', version: '20231013'
+    implementation group: 'org.apache.logging.log4j', name: 'log4j-core', version:"${versions.log4j}"
     api project(':common')
     api project(':core')
     api project(':protocol')
diff --git a/prometheus/build.gradle b/prometheus/build.gradle
index f8c10c7f6b..c2878ab1b4 100644
--- a/prometheus/build.gradle
+++ b/prometheus/build.gradle
@@ -22,7 +22,7 @@ dependencies {
     implementation group: 'com.fasterxml.jackson.core', name: 'jackson-core', version: "${versions.jackson}"
     implementation group: 'com.fasterxml.jackson.core', name: 'jackson-databind', version: "${versions.jackson_databind}"
     implementation group: 'com.fasterxml.jackson.dataformat', name: 'jackson-dataformat-cbor', version: "${versions.jackson}"
-    implementation group: 'org.json', name: 'json', version: '20230227'
+    implementation group: 'org.json', name: 'json', version: '20231013'
     testImplementation('org.junit.jupiter:junit-jupiter:5.6.2')
     testImplementation group: 'org.hamcrest', name: 'hamcrest-library', version: '2.1'
 
diff --git a/spark/build.gradle b/spark/build.gradle
index c06b5b6ecf..8f4388495e 100644
--- a/spark/build.gradle
+++ b/spark/build.gradle
@@ -47,20 +47,44 @@ dependencies {
     implementation project(':datasources')
 
     implementation group: 'org.opensearch', name: 'opensearch', version: "${opensearch_version}"
-    implementation group: 'org.json', name: 'json', version: '20230227'
+    implementation group: 'org.json', name: 'json', version: '20231013'
     api group: 'com.amazonaws', name: 'aws-java-sdk-emr', version: '1.12.545'
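+    // EMR Serverless SDK backs the EMRServerlessClient used for async query submission.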
     api group: 'com.amazonaws', name: 'aws-java-sdk-emrserverless', version: '1.12.545'
     implementation group: 'commons-io', name: 'commons-io', version: '2.8.0'
 
-    testImplementation('org.junit.jupiter:junit-jupiter:5.6.2')
+    testImplementation(platform("org.junit:junit-bom:5.6.2"))
+
+    testImplementation('org.junit.jupiter:junit-jupiter')
     testImplementation group: 'org.mockito', name: 'mockito-core', version: '5.2.0'
     testImplementation group: 'org.mockito', name: 'mockito-junit-jupiter', version: '5.2.0'
-    testImplementation 'junit:junit:4.13.1'
-    testImplementation "org.opensearch.test:framework:${opensearch_version}"
+
+    testCompileOnly('junit:junit:4.13.1') {
+        exclude group: 'org.hamcrest', module: 'hamcrest-core'
+    }
+    testRuntimeOnly("org.junit.vintage:junit-vintage-engine") {
+        exclude group: 'org.hamcrest', module: 'hamcrest-core'
+    }
+    testRuntimeOnly("org.junit.platform:junit-platform-launcher") {
+        because 'allows tests to run from IDEs that bundle older version of launcher'
+    }
+    testImplementation("org.opensearch.test:framework:${opensearch_version}")
+    testImplementation project(':opensearch')
 }
 
 test {
-    useJUnitPlatform()
+    useJUnitPlatform {
+        includeEngines("junit-jupiter")
+    }
+    testLogging {
+        events "failed"
+        exceptionFormat "full"
+    }
+}
+
+task junit4(type: Test) {
+    useJUnitPlatform {
+        includeEngines("junit-vintage")
+    }
+    systemProperty 'tests.security.manager', 'false'
     testLogging {
         events "failed"
         exceptionFormat "full"
@@ -68,6 +92,8 @@ test {
 }
 
 jacocoTestReport {
+    dependsOn test, junit4
+    executionData test, junit4
     reports {
         html.enabled true
         xml.enabled true
@@ -78,9 +104,10 @@ jacocoTestReport {
         }))
     }
 }
-test.finalizedBy(project.tasks.jacocoTestReport)
 
 jacocoTestCoverageVerification {
+    dependsOn test, junit4
+    executionData test, junit4
     violationRules {
         rule {
             element = 'CLASS'
@@ -92,6 +119,10 @@ jacocoTestCoverageVerification {
                     'org.opensearch.sql.spark.asyncquery.exceptions.*',
                     'org.opensearch.sql.spark.dispatcher.model.*',
                     'org.opensearch.sql.spark.flint.FlintIndexType',
+                    // ignore because XContext IOException
+                    'org.opensearch.sql.spark.execution.statestore.StateStore',
+                    'org.opensearch.sql.spark.execution.session.SessionModel',
+                    'org.opensearch.sql.spark.execution.statement.StatementModel'
                 ]
                 limit {
                     counter = 'LINE'
diff --git a/spark/src/main/antlr/FlintSparkSqlExtensions.g4 b/spark/src/main/antlr/FlintSparkSqlExtensions.g4
index e8e0264f28..f48c276e44 100644
--- a/spark/src/main/antlr/FlintSparkSqlExtensions.g4
+++ b/spark/src/main/antlr/FlintSparkSqlExtensions.g4
@@ -17,6 +17,7 @@ singleStatement
 statement
     : skippingIndexStatement
     | coveringIndexStatement
+    | materializedViewStatement
     ;
 
 skippingIndexStatement
@@ -76,6 +77,44 @@ dropCoveringIndexStatement
     : DROP INDEX indexName ON tableName
     ;
 
+materializedViewStatement
+    : createMaterializedViewStatement
+    | refreshMaterializedViewStatement
+    | showMaterializedViewStatement
+    | describeMaterializedViewStatement
+    | dropMaterializedViewStatement
+    ;
+
+createMaterializedViewStatement
+    : CREATE MATERIALIZED VIEW (IF NOT EXISTS)? mvName=multipartIdentifier
+        AS query=materializedViewQuery
+        (WITH LEFT_PAREN propertyList RIGHT_PAREN)?
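+        // illustrative: CREATE MATERIALIZED VIEW mv AS SELECT ... WITH (...)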
+    ;
+
+refreshMaterializedViewStatement
+    : REFRESH MATERIALIZED VIEW mvName=multipartIdentifier
+    ;
+
+showMaterializedViewStatement
+    : SHOW MATERIALIZED (VIEW | VIEWS) IN catalogDb=multipartIdentifier
+    ;
+
+describeMaterializedViewStatement
+    : (DESC | DESCRIBE) MATERIALIZED VIEW mvName=multipartIdentifier
+    ;
+
+dropMaterializedViewStatement
+    : DROP MATERIALIZED VIEW mvName=multipartIdentifier
+    ;
+
+/*
+ * Match all remaining tokens in non-greedy way
+ * so WITH clause won't be captured by this rule.
+ */
+materializedViewQuery
+    : .+?
+    ;
+
 indexColTypeList
     : indexColType (COMMA indexColType)*
     ;
diff --git a/spark/src/main/antlr/SparkSqlBase.g4 b/spark/src/main/antlr/SparkSqlBase.g4
index 4ac1ced5c4..533d851ba6 100644
--- a/spark/src/main/antlr/SparkSqlBase.g4
+++ b/spark/src/main/antlr/SparkSqlBase.g4
@@ -154,6 +154,7 @@ COMMA: ',';
 DOT: '.';
 
 
+AS: 'AS';
 CREATE: 'CREATE';
 DESC: 'DESC';
 DESCRIBE: 'DESCRIBE';
@@ -161,14 +162,18 @@ DROP: 'DROP';
 EXISTS: 'EXISTS';
 FALSE: 'FALSE';
 IF: 'IF';
+IN: 'IN';
 INDEX: 'INDEX';
 INDEXES: 'INDEXES';
+MATERIALIZED: 'MATERIALIZED';
 NOT: 'NOT';
 ON: 'ON';
 PARTITION: 'PARTITION';
 REFRESH: 'REFRESH';
 SHOW: 'SHOW';
 TRUE: 'TRUE';
+VIEW: 'VIEW';
+VIEWS: 'VIEWS';
 WITH: 'WITH';
 
 
diff --git a/spark/src/main/antlr/SqlBaseLexer.g4 b/spark/src/main/antlr/SqlBaseLexer.g4
index d9128de0f5..e8b5cb012f 100644
--- a/spark/src/main/antlr/SqlBaseLexer.g4
+++ b/spark/src/main/antlr/SqlBaseLexer.g4
@@ -447,6 +447,7 @@ PIPE: '|';
 CONCAT_PIPE: '||';
 HAT: '^';
 COLON: ':';
+DOUBLE_COLON: '::';
 ARROW: '->';
 FAT_ARROW : '=>';
 HENT_START: '/*+';
diff --git a/spark/src/main/antlr/SqlBaseParser.g4 b/spark/src/main/antlr/SqlBaseParser.g4
index 6a6d39e96c..84a31dafed 100644
--- a/spark/src/main/antlr/SqlBaseParser.g4
+++ b/spark/src/main/antlr/SqlBaseParser.g4
@@ -957,6 +957,7 @@ primaryExpression
     | CASE whenClause+ (ELSE elseExpression=expression)? END                                   #searchedCase
     | CASE value=expression whenClause+ (ELSE elseExpression=expression)? END                  #simpleCase
     | name=(CAST | TRY_CAST) LEFT_PAREN expression AS dataType RIGHT_PAREN                     #cast
+    | primaryExpression DOUBLE_COLON dataType                                                  #castByColon
     | STRUCT LEFT_PAREN (argument+=namedExpression (COMMA argument+=namedExpression)*)? RIGHT_PAREN #struct
     | FIRST LEFT_PAREN expression (IGNORE NULLS)? RIGHT_PAREN                                  #first
     | ANY_VALUE LEFT_PAREN expression (IGNORE NULLS)? RIGHT_PAREN                              #any_value
@@ -967,7 +968,6 @@ primaryExpression
     | qualifiedName DOT ASTERISK                                                               #star
     | LEFT_PAREN namedExpression (COMMA namedExpression)+ RIGHT_PAREN                          #rowConstructor
     | LEFT_PAREN query RIGHT_PAREN                                                             #subqueryExpression
-    | IDENTIFIER_KW LEFT_PAREN expression RIGHT_PAREN                                          #identifierClause
     | functionName LEFT_PAREN (setQuantifier? argument+=functionArgument
        (COMMA argument+=functionArgument)*)? RIGHT_PAREN
       (FILTER LEFT_PAREN WHERE where=booleanExpression RIGHT_PAREN)?
@@ -1196,6 +1196,7 @@ qualifiedNameList
 functionName
     : IDENTIFIER_KW LEFT_PAREN expression RIGHT_PAREN
+    | identFunc=IDENTIFIER_KW   // IDENTIFIER itself is also a valid function name.
     | qualifiedName
     | FILTER
     | LEFT
diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java
index 13db103f4b..18ae47c2b9 100644
--- a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java
+++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java
@@ -65,14 +65,18 @@ public CreateAsyncQueryResponse createAsyncQuery(
                 createAsyncQueryRequest.getLang(),
                 sparkExecutionEngineConfig.getExecutionRoleARN(),
                 sparkExecutionEngineConfig.getClusterName(),
-                sparkExecutionEngineConfig.getSparkSubmitParameters()));
+                sparkExecutionEngineConfig.getSparkSubmitParameters(),
+                createAsyncQueryRequest.getSessionId()));
     asyncQueryJobMetadataStorageService.storeJobMetadata(
         new AsyncQueryJobMetadata(
+            dispatchQueryResponse.getQueryId(),
             sparkExecutionEngineConfig.getApplicationId(),
             dispatchQueryResponse.getJobId(),
             dispatchQueryResponse.isDropIndexQuery(),
-            dispatchQueryResponse.getResultIndex()));
-    return new CreateAsyncQueryResponse(dispatchQueryResponse.getJobId());
+            dispatchQueryResponse.getResultIndex(),
+            dispatchQueryResponse.getSessionId()));
+    return new CreateAsyncQueryResponse(
+        dispatchQueryResponse.getQueryId().getId(), dispatchQueryResponse.getSessionId());
   }
 
   @Override
@@ -81,6 +85,7 @@ public AsyncQueryExecutionResponse getAsyncQueryResults(String queryId) {
     Optional<AsyncQueryJobMetadata> jobMetadata =
         asyncQueryJobMetadataStorageService.getJobMetadata(queryId);
     if (jobMetadata.isPresent()) {
+      String sessionId = jobMetadata.get().getSessionId();
       JSONObject jsonObject = sparkQueryDispatcher.getQueryResponse(jobMetadata.get());
       if (JobRunState.SUCCESS.toString().equals(jsonObject.getString(STATUS_FIELD))) {
         DefaultSparkSqlFunctionResponseHandle sparkSqlFunctionResponseHandle =
@@ -90,13 +95,18 @@ public AsyncQueryExecutionResponse getAsyncQueryResults(String queryId) {
           result.add(sparkSqlFunctionResponseHandle.next());
         }
         return new AsyncQueryExecutionResponse(
-            JobRunState.SUCCESS.toString(), sparkSqlFunctionResponseHandle.schema(), result, null);
+            JobRunState.SUCCESS.toString(),
+            sparkSqlFunctionResponseHandle.schema(),
+            result,
+            null,
+            sessionId);
       } else {
         return new AsyncQueryExecutionResponse(
             jsonObject.optString(STATUS_FIELD, JobRunState.FAILED.toString()),
             null,
             null,
-            jsonObject.optString(ERROR_FIELD, ""));
+            jsonObject.optString(ERROR_FIELD, ""),
+            sessionId);
       }
     }
     throw new AsyncQueryNotFoundException(String.format("QueryId: %s not found", queryId));
diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryJobMetadataStorageService.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryJobMetadataStorageService.java
index a95a6ffe45..6de8c35f03 100644
--- a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryJobMetadataStorageService.java
+++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryJobMetadataStorageService.java
@@ -7,166 +7,31 @@
 
 package org.opensearch.sql.spark.asyncquery;
 
-import java.io.IOException;
-import java.io.InputStream;
-import java.nio.charset.StandardCharsets;
-import java.util.ArrayList;
-import java.util.List;
+import static org.opensearch.sql.spark.execution.statestore.StateStore.createJobMetaData;
+
 import java.util.Optional;
-import org.apache.commons.io.IOUtils;
-import org.apache.logging.log4j.LogManager;
-import org.apache.logging.log4j.Logger;
-import org.opensearch.action.DocWriteRequest; -import org.opensearch.action.DocWriteResponse; -import org.opensearch.action.admin.indices.create.CreateIndexRequest; -import org.opensearch.action.admin.indices.create.CreateIndexResponse; -import org.opensearch.action.index.IndexRequest; -import org.opensearch.action.index.IndexResponse; -import org.opensearch.action.search.SearchRequest; -import org.opensearch.action.search.SearchResponse; -import org.opensearch.action.support.WriteRequest; -import org.opensearch.client.Client; -import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.action.ActionFuture; -import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.common.xcontent.XContentType; -import org.opensearch.index.query.QueryBuilder; -import org.opensearch.index.query.QueryBuilders; -import org.opensearch.search.SearchHit; -import org.opensearch.search.builder.SearchSourceBuilder; +import lombok.RequiredArgsConstructor; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; +import org.opensearch.sql.spark.execution.statestore.StateStore; /** Opensearch implementation of {@link AsyncQueryJobMetadataStorageService} */ +@RequiredArgsConstructor public class OpensearchAsyncQueryJobMetadataStorageService implements AsyncQueryJobMetadataStorageService { - public static final String JOB_METADATA_INDEX = ".ql-job-metadata"; - private static final String JOB_METADATA_INDEX_MAPPING_FILE_NAME = - "job-metadata-index-mapping.yml"; - private static final String JOB_METADATA_INDEX_SETTINGS_FILE_NAME = - "job-metadata-index-settings.yml"; - private static final Logger LOG = LogManager.getLogger(); - private final Client client; - private final ClusterService clusterService; - - /** - * This class implements JobMetadataStorageService interface using OpenSearch as underlying - * storage. - * - * @param client opensearch NodeClient. - * @param clusterService ClusterService. 
- */ - public OpensearchAsyncQueryJobMetadataStorageService( - Client client, ClusterService clusterService) { - this.client = client; - this.clusterService = clusterService; - } + private final StateStore stateStore; @Override public void storeJobMetadata(AsyncQueryJobMetadata asyncQueryJobMetadata) { - if (!this.clusterService.state().routingTable().hasIndex(JOB_METADATA_INDEX)) { - createJobMetadataIndex(); - } - IndexRequest indexRequest = new IndexRequest(JOB_METADATA_INDEX); - indexRequest.id(asyncQueryJobMetadata.getJobId()); - indexRequest.opType(DocWriteRequest.OpType.CREATE); - indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - ActionFuture indexResponseActionFuture; - IndexResponse indexResponse; - try (ThreadContext.StoredContext storedContext = - client.threadPool().getThreadContext().stashContext()) { - indexRequest.source(AsyncQueryJobMetadata.convertToXContent(asyncQueryJobMetadata)); - indexResponseActionFuture = client.index(indexRequest); - indexResponse = indexResponseActionFuture.actionGet(); - } catch (Exception e) { - throw new RuntimeException(e); - } - - if (indexResponse.getResult().equals(DocWriteResponse.Result.CREATED)) { - LOG.debug("JobMetadata : {} successfully created", asyncQueryJobMetadata.getJobId()); - } else { - throw new RuntimeException( - "Saving job metadata information failed with result : " - + indexResponse.getResult().getLowercase()); - } + AsyncQueryId queryId = asyncQueryJobMetadata.getQueryId(); + createJobMetaData(stateStore, queryId.getDataSourceName()).apply(asyncQueryJobMetadata); } @Override - public Optional getJobMetadata(String jobId) { - if (!this.clusterService.state().routingTable().hasIndex(JOB_METADATA_INDEX)) { - createJobMetadataIndex(); - return Optional.empty(); - } - return searchInJobMetadataIndex(QueryBuilders.termQuery("jobId.keyword", jobId)).stream() - .findFirst(); - } - - private void createJobMetadataIndex() { - try { - InputStream mappingFileStream = - OpensearchAsyncQueryJobMetadataStorageService.class - .getClassLoader() - .getResourceAsStream(JOB_METADATA_INDEX_MAPPING_FILE_NAME); - InputStream settingsFileStream = - OpensearchAsyncQueryJobMetadataStorageService.class - .getClassLoader() - .getResourceAsStream(JOB_METADATA_INDEX_SETTINGS_FILE_NAME); - CreateIndexRequest createIndexRequest = new CreateIndexRequest(JOB_METADATA_INDEX); - createIndexRequest - .mapping(IOUtils.toString(mappingFileStream, StandardCharsets.UTF_8), XContentType.YAML) - .settings( - IOUtils.toString(settingsFileStream, StandardCharsets.UTF_8), XContentType.YAML); - ActionFuture createIndexResponseActionFuture; - try (ThreadContext.StoredContext ignored = - client.threadPool().getThreadContext().stashContext()) { - createIndexResponseActionFuture = client.admin().indices().create(createIndexRequest); - } - CreateIndexResponse createIndexResponse = createIndexResponseActionFuture.actionGet(); - if (createIndexResponse.isAcknowledged()) { - LOG.info("Index: {} creation Acknowledged", JOB_METADATA_INDEX); - } else { - throw new RuntimeException("Index creation is not acknowledged."); - } - } catch (Throwable e) { - throw new RuntimeException( - "Internal server error while creating" - + JOB_METADATA_INDEX - + " index:: " - + e.getMessage()); - } - } - - private List searchInJobMetadataIndex(QueryBuilder query) { - SearchRequest searchRequest = new SearchRequest(); - searchRequest.indices(JOB_METADATA_INDEX); - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); - searchSourceBuilder.query(query); - 
searchSourceBuilder.size(1); - searchRequest.source(searchSourceBuilder); - // https://github.com/opensearch-project/sql/issues/1801. - searchRequest.preference("_primary_first"); - ActionFuture searchResponseActionFuture; - try (ThreadContext.StoredContext ignored = - client.threadPool().getThreadContext().stashContext()) { - searchResponseActionFuture = client.search(searchRequest); - } - SearchResponse searchResponse = searchResponseActionFuture.actionGet(); - if (searchResponse.status().getStatus() != 200) { - throw new RuntimeException( - "Fetching job metadata information failed with status : " + searchResponse.status()); - } else { - List list = new ArrayList<>(); - for (SearchHit searchHit : searchResponse.getHits().getHits()) { - String sourceAsString = searchHit.getSourceAsString(); - AsyncQueryJobMetadata asyncQueryJobMetadata; - try { - asyncQueryJobMetadata = AsyncQueryJobMetadata.toJobMetadata(sourceAsString); - } catch (IOException e) { - throw new RuntimeException(e); - } - list.add(asyncQueryJobMetadata); - } - return list; - } + public Optional getJobMetadata(String qid) { + AsyncQueryId queryId = new AsyncQueryId(qid); + return StateStore.getJobMetaData(stateStore, queryId.getDataSourceName()) + .apply(queryId.docId()); } } diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryExecutionResponse.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryExecutionResponse.java index d2e54af004..e5d9cffd5f 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryExecutionResponse.java +++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryExecutionResponse.java @@ -19,4 +19,5 @@ public class AsyncQueryExecutionResponse { private final ExecutionEngine.Schema schema; private final List results; private final String error; + private final String sessionId; } diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryId.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryId.java new file mode 100644 index 0000000000..b99ebe0e8c --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryId.java @@ -0,0 +1,35 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.asyncquery.model; + +import static org.opensearch.sql.spark.utils.IDUtils.decode; +import static org.opensearch.sql.spark.utils.IDUtils.encode; + +import lombok.Data; + +/** Async query id. */ +@Data +public class AsyncQueryId { + private final String id; + + public static AsyncQueryId newAsyncQueryId(String datasourceName) { + return new AsyncQueryId(encode(datasourceName)); + } + + public String getDataSourceName() { + return decode(id); + } + + /** OpenSearch DocId. 
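+   *
+   * <p>Illustrative example (made-up id): for an AsyncQueryId with id "00fdmys3", docId()
+   * returns "qid00fdmys3", which is the document id used when the job metadata is stored in
+   * and read back from the state store.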
*/ + public String docId() { + return "qid" + id; + } + + @Override + public String toString() { + return "asyncQueryId=" + id; + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryJobMetadata.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryJobMetadata.java index b470ef989f..3c59403661 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryJobMetadata.java +++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryJobMetadata.java @@ -8,34 +8,83 @@ package org.opensearch.sql.spark.asyncquery.model; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.sql.spark.execution.statement.StatementModel.QUERY_ID; import com.google.gson.Gson; import java.io.IOException; -import lombok.AllArgsConstructor; import lombok.Data; import lombok.EqualsAndHashCode; -import org.opensearch.common.xcontent.XContentFactory; -import org.opensearch.common.xcontent.XContentType; -import org.opensearch.core.xcontent.DeprecationHandler; -import org.opensearch.core.xcontent.NamedXContentRegistry; +import lombok.SneakyThrows; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.index.seqno.SequenceNumbers; +import org.opensearch.sql.spark.execution.statestore.StateModel; /** This class models all the metadata required for a job. */ @Data -@AllArgsConstructor -@EqualsAndHashCode -public class AsyncQueryJobMetadata { - private String applicationId; - private String jobId; - private boolean isDropIndexQuery; - private String resultIndex; +@EqualsAndHashCode(callSuper = false) +public class AsyncQueryJobMetadata extends StateModel { + public static final String TYPE_JOBMETA = "jobmeta"; - public AsyncQueryJobMetadata(String applicationId, String jobId, String resultIndex) { + private final AsyncQueryId queryId; + private final String applicationId; + private final String jobId; + private final boolean isDropIndexQuery; + private final String resultIndex; + // optional sessionId. + private final String sessionId; + + @EqualsAndHashCode.Exclude private final long seqNo; + @EqualsAndHashCode.Exclude private final long primaryTerm; + + public AsyncQueryJobMetadata( + AsyncQueryId queryId, String applicationId, String jobId, String resultIndex) { + this( + queryId, + applicationId, + jobId, + false, + resultIndex, + null, + SequenceNumbers.UNASSIGNED_SEQ_NO, + SequenceNumbers.UNASSIGNED_PRIMARY_TERM); + } + + public AsyncQueryJobMetadata( + AsyncQueryId queryId, + String applicationId, + String jobId, + boolean isDropIndexQuery, + String resultIndex, + String sessionId) { + this( + queryId, + applicationId, + jobId, + isDropIndexQuery, + resultIndex, + sessionId, + SequenceNumbers.UNASSIGNED_SEQ_NO, + SequenceNumbers.UNASSIGNED_PRIMARY_TERM); + } + + public AsyncQueryJobMetadata( + AsyncQueryId queryId, + String applicationId, + String jobId, + boolean isDropIndexQuery, + String resultIndex, + String sessionId, + long seqNo, + long primaryTerm) { + this.queryId = queryId; this.applicationId = applicationId; this.jobId = jobId; - this.isDropIndexQuery = false; + this.isDropIndexQuery = isDropIndexQuery; this.resultIndex = resultIndex; + this.sessionId = sessionId; + this.seqNo = seqNo; + this.primaryTerm = primaryTerm; } @Override @@ -46,38 +95,36 @@ public String toString() { /** * Converts JobMetadata to XContentBuilder. * - * @param metadata metadata. 
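+   * <p>Illustrative serialized form (field values made up): {"queryId": "00fdmys3", "type":
+   * "jobmeta", "jobId": "00fdjob", "applicationId": "00fdapp", "isDropIndexQuery": false,
+   * "resultIndex": null, "sessionId": "00fdsession"}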
* @return XContentBuilder {@link XContentBuilder} * @throws Exception Exception. */ - public static XContentBuilder convertToXContent(AsyncQueryJobMetadata metadata) throws Exception { - XContentBuilder builder = XContentFactory.jsonBuilder(); - builder.startObject(); - builder.field("jobId", metadata.getJobId()); - builder.field("applicationId", metadata.getApplicationId()); - builder.field("isDropIndexQuery", metadata.isDropIndexQuery()); - builder.field("resultIndex", metadata.getResultIndex()); - builder.endObject(); + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder + .startObject() + .field(QUERY_ID, queryId.getId()) + .field("type", TYPE_JOBMETA) + .field("jobId", jobId) + .field("applicationId", applicationId) + .field("isDropIndexQuery", isDropIndexQuery) + .field("resultIndex", resultIndex) + .field("sessionId", sessionId) + .endObject(); return builder; } - /** - * Converts json string to DataSourceMetadata. - * - * @param json jsonstring. - * @return jobmetadata {@link AsyncQueryJobMetadata} - * @throws java.io.IOException IOException. - */ - public static AsyncQueryJobMetadata toJobMetadata(String json) throws IOException { - try (XContentParser parser = - XContentType.JSON - .xContent() - .createParser( - NamedXContentRegistry.EMPTY, - DeprecationHandler.THROW_UNSUPPORTED_OPERATION, - json)) { - return toJobMetadata(parser); - } + /** copy builder. update seqNo and primaryTerm */ + public static AsyncQueryJobMetadata copy( + AsyncQueryJobMetadata copy, long seqNo, long primaryTerm) { + return new AsyncQueryJobMetadata( + copy.getQueryId(), + copy.getApplicationId(), + copy.getJobId(), + copy.isDropIndexQuery(), + copy.getResultIndex(), + copy.getSessionId(), + seqNo, + primaryTerm); } /** @@ -87,16 +134,23 @@ public static AsyncQueryJobMetadata toJobMetadata(String json) throws IOExceptio * @return JobMetadata {@link AsyncQueryJobMetadata} * @throws IOException IOException. 
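+   * @throws IllegalArgumentException if an unknown field is encountered, or if jobId or
+   *     applicationId is missing (see the validation at the end of this method)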
*/ - public static AsyncQueryJobMetadata toJobMetadata(XContentParser parser) throws IOException { + @SneakyThrows + public static AsyncQueryJobMetadata fromXContent( + XContentParser parser, long seqNo, long primaryTerm) { + AsyncQueryId queryId = null; String jobId = null; String applicationId = null; boolean isDropIndexQuery = false; String resultIndex = null; - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String sessionId = null; + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (!XContentParser.Token.END_OBJECT.equals(parser.nextToken())) { String fieldName = parser.currentName(); parser.nextToken(); switch (fieldName) { + case QUERY_ID: + queryId = new AsyncQueryId(parser.textOrNull()); + break; case "jobId": jobId = parser.textOrNull(); break; @@ -109,6 +163,11 @@ public static AsyncQueryJobMetadata toJobMetadata(XContentParser parser) throws case "resultIndex": resultIndex = parser.textOrNull(); break; + case "sessionId": + sessionId = parser.textOrNull(); + break; + case "type": + break; default: throw new IllegalArgumentException("Unknown field: " + fieldName); } @@ -116,6 +175,19 @@ public static AsyncQueryJobMetadata toJobMetadata(XContentParser parser) throws if (jobId == null || applicationId == null) { throw new IllegalArgumentException("jobId and applicationId are required fields."); } - return new AsyncQueryJobMetadata(applicationId, jobId, isDropIndexQuery, resultIndex); + return new AsyncQueryJobMetadata( + queryId, + applicationId, + jobId, + isDropIndexQuery, + resultIndex, + sessionId, + seqNo, + primaryTerm); + } + + @Override + public String getId() { + return queryId.docId(); } } diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/SparkSubmitParameters.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/SparkSubmitParameters.java index 0609d8903c..9a73b0f364 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/SparkSubmitParameters.java +++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/SparkSubmitParameters.java @@ -12,6 +12,7 @@ import static org.opensearch.sql.datasources.glue.GlueDataSourceFactory.GLUE_INDEX_STORE_OPENSEARCH_URI; import static org.opensearch.sql.datasources.glue.GlueDataSourceFactory.GLUE_ROLE_ARN; import static org.opensearch.sql.spark.data.constants.SparkConstants.*; +import static org.opensearch.sql.spark.execution.statestore.StateStore.DATASOURCE_TO_REQUEST_INDEX; import java.net.URI; import java.net.URISyntaxException; @@ -30,6 +31,7 @@ public class SparkSubmitParameters { public static final String SPACE = " "; public static final String EQUALS = "="; + public static final String FLINT_BASIC_AUTH = "basic"; private final String className; private final Map config; @@ -39,7 +41,7 @@ public class SparkSubmitParameters { public static class Builder { - private final String className; + private String className; private final Map config; private String extraParameters; @@ -70,6 +72,11 @@ public static Builder builder() { return new Builder(); } + public Builder className(String className) { + this.className = className; + return this; + } + public Builder dataSource(DataSourceMetadata metadata) { if (DataSourceType.S3GLUE.equals(metadata.getConnector())) { String roleArn = metadata.getProperties().get(GLUE_ROLE_ARN); @@ -108,7 +115,7 @@ private void setFlintIndexStoreAuthProperties( Supplier password, Supplier 
region) { if (AuthenticationType.get(authType).equals(AuthenticationType.BASICAUTH)) { - config.put(FLINT_INDEX_STORE_AUTH_KEY, authType); + config.put(FLINT_INDEX_STORE_AUTH_KEY, FLINT_BASIC_AUTH); config.put(FLINT_INDEX_STORE_AUTH_USERNAME, userName.get()); config.put(FLINT_INDEX_STORE_AUTH_PASSWORD, password.get()); } else if (AuthenticationType.get(authType).equals(AuthenticationType.AWSSIGV4AUTH)) { @@ -141,6 +148,12 @@ public Builder extraParameters(String params) { return this; } + public Builder sessionExecution(String sessionId, String datasourceName) { + config.put(FLINT_JOB_REQUEST_INDEX, DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName)); + config.put(FLINT_JOB_SESSION_ID, sessionId); + return this; + } + public SparkSubmitParameters build() { return new SparkSubmitParameters(className, config, extraParameters); } diff --git a/spark/src/main/java/org/opensearch/sql/spark/client/StartJobRequest.java b/spark/src/main/java/org/opensearch/sql/spark/client/StartJobRequest.java index c4382239a1..f57c8facee 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/client/StartJobRequest.java +++ b/spark/src/main/java/org/opensearch/sql/spark/client/StartJobRequest.java @@ -7,12 +7,14 @@ import java.util.Map; import lombok.Data; +import lombok.EqualsAndHashCode; /** * This POJO carries all the fields required for emr serverless job submission. Used as model in * {@link EMRServerlessClient} interface. */ @Data +@EqualsAndHashCode public class StartJobRequest { public static final Long DEFAULT_JOB_TIMEOUT = 120L; diff --git a/spark/src/main/java/org/opensearch/sql/spark/data/constants/SparkConstants.java b/spark/src/main/java/org/opensearch/sql/spark/data/constants/SparkConstants.java index 284afcc0a9..e8659c680c 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/data/constants/SparkConstants.java +++ b/spark/src/main/java/org/opensearch/sql/spark/data/constants/SparkConstants.java @@ -21,12 +21,11 @@ public class SparkConstants { public static final String SPARK_SQL_APPLICATION_JAR = "file:///home/hadoop/.ivy2/jars/org.opensearch_opensearch-spark-sql-application_2.12-0.1.0-SNAPSHOT.jar"; public static final String SPARK_RESPONSE_BUFFER_INDEX_NAME = ".query_execution_result"; + public static final String SPARK_REQUEST_BUFFER_INDEX_NAME = ".query_execution_request"; // TODO should be replaced with mvn jar. public static final String FLINT_INTEGRATION_JAR = "s3://spark-datasource/flint-spark-integration-assembly-0.1.0-SNAPSHOT.jar"; // TODO should be replaced with mvn jar. 
- public static final String FLINT_CATALOG_JAR = - "s3://flint-data-dp-eu-west-1-beta/code/flint/flint-catalog.jar"; public static final String FLINT_DEFAULT_HOST = "localhost"; public static final String FLINT_DEFAULT_PORT = "9200"; public static final String FLINT_DEFAULT_SCHEME = "http"; @@ -86,4 +85,8 @@ public class SparkConstants { public static final String EMR_ASSUME_ROLE_CREDENTIALS_PROVIDER = "com.amazonaws.emr.AssumeRoleAWSCredentialsProvider"; public static final String JAVA_HOME_LOCATION = "/usr/lib/jvm/java-17-amazon-corretto.x86_64/"; + + public static final String FLINT_JOB_REQUEST_INDEX = "spark.flint.job.requestIndex"; + public static final String FLINT_JOB_SESSION_ID = "spark.flint.job.sessionId"; + public static final String FLINT_SESSION_CLASS_NAME = "org.apache.spark.sql.FlintREPL"; } diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/AsyncQueryHandler.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/AsyncQueryHandler.java new file mode 100644 index 0000000000..77a0e1cd09 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/AsyncQueryHandler.java @@ -0,0 +1,49 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.dispatcher; + +import static org.opensearch.sql.spark.data.constants.SparkConstants.DATA_FIELD; +import static org.opensearch.sql.spark.data.constants.SparkConstants.ERROR_FIELD; +import static org.opensearch.sql.spark.data.constants.SparkConstants.STATUS_FIELD; + +import com.amazonaws.services.emrserverless.model.JobRunState; +import org.json.JSONObject; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; + +/** Process async query request. */ +public abstract class AsyncQueryHandler { + + public JSONObject getQueryResponse(AsyncQueryJobMetadata asyncQueryJobMetadata) { + if (asyncQueryJobMetadata.isDropIndexQuery()) { + return SparkQueryDispatcher.DropIndexResult.fromJobId(asyncQueryJobMetadata.getJobId()) + .result(); + } + + JSONObject result = getResponseFromResultIndex(asyncQueryJobMetadata); + if (result.has(DATA_FIELD)) { + JSONObject items = result.getJSONObject(DATA_FIELD); + + // If items have STATUS_FIELD, use it; otherwise, mark failed + String status = items.optString(STATUS_FIELD, JobRunState.FAILED.toString()); + result.put(STATUS_FIELD, status); + + // If items have ERROR_FIELD, use it; otherwise, set empty string + String error = items.optString(ERROR_FIELD, ""); + result.put(ERROR_FIELD, error); + return result; + } else { + return getResponseFromExecutor(asyncQueryJobMetadata); + } + } + + protected abstract JSONObject getResponseFromResultIndex( + AsyncQueryJobMetadata asyncQueryJobMetadata); + + protected abstract JSONObject getResponseFromExecutor( + AsyncQueryJobMetadata asyncQueryJobMetadata); + + abstract String cancelJob(AsyncQueryJobMetadata asyncQueryJobMetadata); +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java new file mode 100644 index 0000000000..8a582278e1 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java @@ -0,0 +1,50 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.dispatcher; + +import static org.opensearch.sql.spark.data.constants.SparkConstants.ERROR_FIELD; +import static 
org.opensearch.sql.spark.data.constants.SparkConstants.STATUS_FIELD; + +import com.amazonaws.services.emrserverless.model.GetJobRunResult; +import lombok.RequiredArgsConstructor; +import org.json.JSONObject; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; +import org.opensearch.sql.spark.client.EMRServerlessClient; +import org.opensearch.sql.spark.response.JobExecutionResponseReader; + +@RequiredArgsConstructor +public class BatchQueryHandler extends AsyncQueryHandler { + private final EMRServerlessClient emrServerlessClient; + private final JobExecutionResponseReader jobExecutionResponseReader; + + @Override + protected JSONObject getResponseFromResultIndex(AsyncQueryJobMetadata asyncQueryJobMetadata) { + // either empty json when the result is not available or data with status + // Fetch from Result Index + return jobExecutionResponseReader.getResultFromOpensearchIndex( + asyncQueryJobMetadata.getJobId(), asyncQueryJobMetadata.getResultIndex()); + } + + @Override + protected JSONObject getResponseFromExecutor(AsyncQueryJobMetadata asyncQueryJobMetadata) { + JSONObject result = new JSONObject(); + // make call to EMR Serverless when related result index documents are not available + GetJobRunResult getJobRunResult = + emrServerlessClient.getJobRunResult( + asyncQueryJobMetadata.getApplicationId(), asyncQueryJobMetadata.getJobId()); + String jobState = getJobRunResult.getJobRun().getState(); + result.put(STATUS_FIELD, jobState); + result.put(ERROR_FIELD, ""); + return result; + } + + @Override + public String cancelJob(AsyncQueryJobMetadata asyncQueryJobMetadata) { + emrServerlessClient.cancelJobRun( + asyncQueryJobMetadata.getApplicationId(), asyncQueryJobMetadata.getJobId()); + return asyncQueryJobMetadata.getQueryId().getId(); + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/InteractiveQueryHandler.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/InteractiveQueryHandler.java new file mode 100644 index 0000000000..52cc2efbe2 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/InteractiveQueryHandler.java @@ -0,0 +1,69 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.dispatcher; + +import static org.opensearch.sql.spark.data.constants.SparkConstants.ERROR_FIELD; +import static org.opensearch.sql.spark.data.constants.SparkConstants.STATUS_FIELD; + +import java.util.Optional; +import lombok.RequiredArgsConstructor; +import org.json.JSONObject; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; +import org.opensearch.sql.spark.execution.session.Session; +import org.opensearch.sql.spark.execution.session.SessionId; +import org.opensearch.sql.spark.execution.session.SessionManager; +import org.opensearch.sql.spark.execution.statement.Statement; +import org.opensearch.sql.spark.execution.statement.StatementId; +import org.opensearch.sql.spark.execution.statement.StatementState; +import org.opensearch.sql.spark.response.JobExecutionResponseReader; + +@RequiredArgsConstructor +public class InteractiveQueryHandler extends AsyncQueryHandler { + private final SessionManager sessionManager; + private final JobExecutionResponseReader jobExecutionResponseReader; + + @Override + protected JSONObject getResponseFromResultIndex(AsyncQueryJobMetadata asyncQueryJobMetadata) { + String queryId = asyncQueryJobMetadata.getQueryId().getId(); + return jobExecutionResponseReader.getResultWithQueryId( + queryId, 
asyncQueryJobMetadata.getResultIndex()); + } + + @Override + protected JSONObject getResponseFromExecutor(AsyncQueryJobMetadata asyncQueryJobMetadata) { + JSONObject result = new JSONObject(); + String queryId = asyncQueryJobMetadata.getQueryId().getId(); + Statement statement = getStatementByQueryId(asyncQueryJobMetadata.getSessionId(), queryId); + StatementState statementState = statement.getStatementState(); + result.put(STATUS_FIELD, statementState.getState()); + result.put(ERROR_FIELD, Optional.of(statement.getStatementModel().getError()).orElse("")); + return result; + } + + @Override + public String cancelJob(AsyncQueryJobMetadata asyncQueryJobMetadata) { + String queryId = asyncQueryJobMetadata.getQueryId().getId(); + getStatementByQueryId(asyncQueryJobMetadata.getSessionId(), queryId).cancel(); + return queryId; + } + + private Statement getStatementByQueryId(String sid, String qid) { + SessionId sessionId = new SessionId(sid); + Optional session = sessionManager.getSession(sessionId); + if (session.isPresent()) { + // todo, statementId == jobId if statement running in session. + StatementId statementId = new StatementId(qid); + Optional statement = session.get().get(statementId); + if (statement.isPresent()) { + return statement.get(); + } else { + throw new IllegalArgumentException("no statement found. " + statementId); + } + } else { + throw new IllegalArgumentException("no session found. " + sessionId); + } + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java index 347e154885..8feeddcafc 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java @@ -7,15 +7,15 @@ import static org.opensearch.sql.spark.data.constants.SparkConstants.DATA_FIELD; import static org.opensearch.sql.spark.data.constants.SparkConstants.ERROR_FIELD; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_SESSION_CLASS_NAME; import static org.opensearch.sql.spark.data.constants.SparkConstants.STATUS_FIELD; -import com.amazonaws.services.emrserverless.model.CancelJobRunResult; -import com.amazonaws.services.emrserverless.model.GetJobRunResult; import com.amazonaws.services.emrserverless.model.JobRunState; import java.nio.charset.StandardCharsets; import java.util.Base64; import java.util.HashMap; import java.util.Map; +import java.util.Optional; import java.util.concurrent.ExecutionException; import lombok.AllArgsConstructor; import lombok.Getter; @@ -31,14 +31,17 @@ import org.opensearch.sql.datasource.DataSourceService; import org.opensearch.sql.datasource.model.DataSourceMetadata; import org.opensearch.sql.datasources.auth.DataSourceUserAuthorizationHelperImpl; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; import org.opensearch.sql.spark.asyncquery.model.SparkSubmitParameters; import org.opensearch.sql.spark.client.EMRServerlessClient; import org.opensearch.sql.spark.client.StartJobRequest; -import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; -import org.opensearch.sql.spark.dispatcher.model.DispatchQueryResponse; -import org.opensearch.sql.spark.dispatcher.model.FullyQualifiedTableName; -import org.opensearch.sql.spark.dispatcher.model.IndexDetails; +import org.opensearch.sql.spark.dispatcher.model.*; +import 
org.opensearch.sql.spark.execution.session.CreateSessionRequest; +import org.opensearch.sql.spark.execution.session.Session; +import org.opensearch.sql.spark.execution.session.SessionId; +import org.opensearch.sql.spark.execution.session.SessionManager; +import org.opensearch.sql.spark.execution.statement.QueryRequest; import org.opensearch.sql.spark.flint.FlintIndexMetadata; import org.opensearch.sql.spark.flint.FlintIndexMetadataReader; import org.opensearch.sql.spark.response.JobExecutionResponseReader; @@ -50,12 +53,10 @@ public class SparkQueryDispatcher { private static final Logger LOG = LogManager.getLogger(); - public static final String INDEX_TAG_KEY = "index"; public static final String DATASOURCE_TAG_KEY = "datasource"; - public static final String SCHEMA_TAG_KEY = "schema"; - public static final String TABLE_TAG_KEY = "table"; public static final String CLUSTER_NAME_TAG_KEY = "cluster"; + public static final String JOB_TYPE_TAG_KEY = "type"; private EMRServerlessClient emrServerlessClient; @@ -69,6 +70,8 @@ public class SparkQueryDispatcher { private Client client; + private SessionManager sessionManager; + public DispatchQueryResponse dispatch(DispatchQueryRequest dispatchQueryRequest) { if (LangType.SQL.equals(dispatchQueryRequest.getLangType())) { return handleSQLQuery(dispatchQueryRequest); @@ -80,81 +83,69 @@ public DispatchQueryResponse dispatch(DispatchQueryRequest dispatchQueryRequest) } public JSONObject getQueryResponse(AsyncQueryJobMetadata asyncQueryJobMetadata) { - // todo. refactor query process logic in plugin. - if (asyncQueryJobMetadata.isDropIndexQuery()) { - return DropIndexResult.fromJobId(asyncQueryJobMetadata.getJobId()).result(); - } - - // either empty json when the result is not available or data with status - // Fetch from Result Index - JSONObject result = - jobExecutionResponseReader.getResultFromOpensearchIndex( - asyncQueryJobMetadata.getJobId(), asyncQueryJobMetadata.getResultIndex()); - - // if result index document has a status, we are gonna use the status directly; otherwise, we - // will use emr-s job status. - // That a job is successful does not mean there is no error in execution. For example, even if - // result - // index mapping is incorrect, we still write query result and let the job finish. - // That a job is running does not mean the status is running. For example, index/streaming Query - // is a - // long-running job which runs forever. But we need to return success from the result index - // immediately. 
-    if (result.has(DATA_FIELD)) {
-      JSONObject items = result.getJSONObject(DATA_FIELD);
-
-      // If items have STATUS_FIELD, use it; otherwise, mark failed
-      String status = items.optString(STATUS_FIELD, JobRunState.FAILED.toString());
-      result.put(STATUS_FIELD, status);
-
-      // If items have ERROR_FIELD, use it; otherwise, set empty string
-      String error = items.optString(ERROR_FIELD, "");
-      result.put(ERROR_FIELD, error);
+    if (asyncQueryJobMetadata.getSessionId() != null) {
+      return new InteractiveQueryHandler(sessionManager, jobExecutionResponseReader)
+          .getQueryResponse(asyncQueryJobMetadata);
     } else {
-      // make call to EMR Serverless when related result index documents are not available
-      GetJobRunResult getJobRunResult =
-          emrServerlessClient.getJobRunResult(
-              asyncQueryJobMetadata.getApplicationId(), asyncQueryJobMetadata.getJobId());
-      String jobState = getJobRunResult.getJobRun().getState();
-      result.put(STATUS_FIELD, jobState);
-      result.put(ERROR_FIELD, "");
+      return new BatchQueryHandler(emrServerlessClient, jobExecutionResponseReader)
+          .getQueryResponse(asyncQueryJobMetadata);
     }
-
-    return result;
   }
 
   public String cancelJob(AsyncQueryJobMetadata asyncQueryJobMetadata) {
-    CancelJobRunResult cancelJobRunResult =
-        emrServerlessClient.cancelJobRun(
-            asyncQueryJobMetadata.getApplicationId(), asyncQueryJobMetadata.getJobId());
-    return cancelJobRunResult.getJobRunId();
+    if (asyncQueryJobMetadata.getSessionId() != null) {
+      return new InteractiveQueryHandler(sessionManager, jobExecutionResponseReader)
+          .cancelJob(asyncQueryJobMetadata);
+    } else {
+      return new BatchQueryHandler(emrServerlessClient, jobExecutionResponseReader)
+          .cancelJob(asyncQueryJobMetadata);
+    }
   }
 
   private DispatchQueryResponse handleSQLQuery(DispatchQueryRequest dispatchQueryRequest) {
-    if (SQLQueryUtils.isIndexQuery(dispatchQueryRequest.getQuery())) {
-      IndexDetails indexDetails =
+    if (SQLQueryUtils.isFlintExtensionQuery(dispatchQueryRequest.getQuery())) {
+      IndexQueryDetails indexQueryDetails =
           SQLQueryUtils.extractIndexDetails(dispatchQueryRequest.getQuery());
-      if (indexDetails.isDropIndex()) {
-        return handleDropIndexQuery(dispatchQueryRequest, indexDetails);
+      fillMissingDetails(dispatchQueryRequest, indexQueryDetails);
+
+      // TODO: refactor this code properly.
+      if (IndexQueryActionType.DROP.equals(indexQueryDetails.getIndexQueryActionType())) {
+        return handleDropIndexQuery(dispatchQueryRequest, indexQueryDetails);
+      } else if (IndexQueryActionType.CREATE.equals(indexQueryDetails.getIndexQueryActionType())) {
+        return handleStreamingQueries(dispatchQueryRequest, indexQueryDetails);
       } else {
-        return handleIndexQuery(dispatchQueryRequest, indexDetails);
+        return handleFlintNonStreamingQueries(dispatchQueryRequest, indexQueryDetails);
       }
     } else {
       return handleNonIndexQuery(dispatchQueryRequest);
     }
   }
 
-  private DispatchQueryResponse handleIndexQuery(
-      DispatchQueryRequest dispatchQueryRequest, IndexDetails indexDetails) {
-    FullyQualifiedTableName fullyQualifiedTableName = indexDetails.getFullyQualifiedTableName();
+  // TODO: Revisit this logic.
+  // Currently, if the datasource is not provided in the query, Spark assumes the catalog to be
+  // the datasource. This is required to handle the drop index case properly when the datasource
+  // name is not provided.
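+  // Example (illustrative): "DROP SKIPPING INDEX ON default.http_logs" dispatched against
+  // datasource "my_glue" gets its table qualified as my_glue.default.http_logs before the
+  // Flint index name is resolved.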
+ private static void fillMissingDetails( + DispatchQueryRequest dispatchQueryRequest, IndexQueryDetails indexQueryDetails) { + if (indexQueryDetails.getFullyQualifiedTableName() != null + && indexQueryDetails.getFullyQualifiedTableName().getDatasourceName() == null) { + indexQueryDetails + .getFullyQualifiedTableName() + .setDatasourceName(dispatchQueryRequest.getDatasource()); + } + } + + private DispatchQueryResponse handleStreamingQueries( + DispatchQueryRequest dispatchQueryRequest, IndexQueryDetails indexQueryDetails) { DataSourceMetadata dataSourceMetadata = this.dataSourceService.getRawDataSourceMetadata(dispatchQueryRequest.getDatasource()); dataSourceUserAuthorizationHelper.authorizeDataSource(dataSourceMetadata); String jobName = dispatchQueryRequest.getClusterName() + ":" + "index-query"; Map tags = getDefaultTagsForJobSubmission(dispatchQueryRequest); - tags.put(INDEX_TAG_KEY, indexDetails.getIndexName()); - tags.put(TABLE_TAG_KEY, fullyQualifiedTableName.getTableName()); - tags.put(SCHEMA_TAG_KEY, fullyQualifiedTableName.getSchemaName()); + tags.put(INDEX_TAG_KEY, indexQueryDetails.openSearchIndexName()); + if (indexQueryDetails.isAutoRefresh()) { + tags.put(JOB_TYPE_TAG_KEY, JobType.STREAMING.getText()); + } StartJobRequest startJobRequest = new StartJobRequest( dispatchQueryRequest.getQuery(), @@ -165,22 +156,28 @@ private DispatchQueryResponse handleIndexQuery( .dataSource( dataSourceService.getRawDataSourceMetadata( dispatchQueryRequest.getDatasource())) - .structuredStreaming(indexDetails.getAutoRefresh()) + .structuredStreaming(indexQueryDetails.isAutoRefresh()) .extraParameters(dispatchQueryRequest.getExtraSparkSubmitParams()) .build() .toString(), tags, - indexDetails.getAutoRefresh(), + indexQueryDetails.isAutoRefresh(), dataSourceMetadata.getResultIndex()); String jobId = emrServerlessClient.startJobRun(startJobRequest); - return new DispatchQueryResponse(jobId, false, dataSourceMetadata.getResultIndex()); + return new DispatchQueryResponse( + AsyncQueryId.newAsyncQueryId(dataSourceMetadata.getName()), + jobId, + false, + dataSourceMetadata.getResultIndex(), + null); } - private DispatchQueryResponse handleNonIndexQuery(DispatchQueryRequest dispatchQueryRequest) { + private DispatchQueryResponse handleFlintNonStreamingQueries( + DispatchQueryRequest dispatchQueryRequest, IndexQueryDetails indexQueryDetails) { DataSourceMetadata dataSourceMetadata = this.dataSourceService.getRawDataSourceMetadata(dispatchQueryRequest.getDatasource()); dataSourceUserAuthorizationHelper.authorizeDataSource(dataSourceMetadata); - String jobName = dispatchQueryRequest.getClusterName() + ":" + "non-index-query"; + String jobName = dispatchQueryRequest.getClusterName() + ":" + "index-query"; Map tags = getDefaultTagsForJobSubmission(dispatchQueryRequest); StartJobRequest startJobRequest = new StartJobRequest( @@ -196,18 +193,96 @@ private DispatchQueryResponse handleNonIndexQuery(DispatchQueryRequest dispatchQ .build() .toString(), tags, - false, + indexQueryDetails.isAutoRefresh(), dataSourceMetadata.getResultIndex()); String jobId = emrServerlessClient.startJobRun(startJobRequest); - return new DispatchQueryResponse(jobId, false, dataSourceMetadata.getResultIndex()); + return new DispatchQueryResponse( + AsyncQueryId.newAsyncQueryId(dataSourceMetadata.getName()), + jobId, + false, + dataSourceMetadata.getResultIndex(), + null); + } + + private DispatchQueryResponse handleNonIndexQuery(DispatchQueryRequest dispatchQueryRequest) { + DataSourceMetadata dataSourceMetadata = + 
this.dataSourceService.getRawDataSourceMetadata(dispatchQueryRequest.getDatasource()); + AsyncQueryId queryId = AsyncQueryId.newAsyncQueryId(dataSourceMetadata.getName()); + dataSourceUserAuthorizationHelper.authorizeDataSource(dataSourceMetadata); + String jobName = dispatchQueryRequest.getClusterName() + ":" + "non-index-query"; + Map tags = getDefaultTagsForJobSubmission(dispatchQueryRequest); + + if (sessionManager.isEnabled()) { + Session session = null; + + if (dispatchQueryRequest.getSessionId() != null) { + // get session from request + SessionId sessionId = new SessionId(dispatchQueryRequest.getSessionId()); + Optional createdSession = sessionManager.getSession(sessionId); + if (createdSession.isEmpty()) { + throw new IllegalArgumentException("no session found. " + sessionId); + } + session = createdSession.get(); + } + if (session == null || !session.isReady()) { + // create session if not exist or session dead/fail + tags.put(JOB_TYPE_TAG_KEY, JobType.INTERACTIVE.getText()); + session = + sessionManager.createSession( + new CreateSessionRequest( + jobName, + dispatchQueryRequest.getApplicationId(), + dispatchQueryRequest.getExecutionRoleARN(), + SparkSubmitParameters.Builder.builder() + .className(FLINT_SESSION_CLASS_NAME) + .dataSource( + dataSourceService.getRawDataSourceMetadata( + dispatchQueryRequest.getDatasource())) + .extraParameters(dispatchQueryRequest.getExtraSparkSubmitParams()), + tags, + dataSourceMetadata.getResultIndex(), + dataSourceMetadata.getName())); + } + session.submit( + new QueryRequest( + queryId, dispatchQueryRequest.getLangType(), dispatchQueryRequest.getQuery())); + return new DispatchQueryResponse( + queryId, + session.getSessionModel().getJobId(), + false, + dataSourceMetadata.getResultIndex(), + session.getSessionId().getSessionId()); + } else { + tags.put(JOB_TYPE_TAG_KEY, JobType.BATCH.getText()); + StartJobRequest startJobRequest = + new StartJobRequest( + dispatchQueryRequest.getQuery(), + jobName, + dispatchQueryRequest.getApplicationId(), + dispatchQueryRequest.getExecutionRoleARN(), + SparkSubmitParameters.Builder.builder() + .dataSource( + dataSourceService.getRawDataSourceMetadata( + dispatchQueryRequest.getDatasource())) + .extraParameters(dispatchQueryRequest.getExtraSparkSubmitParams()) + .build() + .toString(), + tags, + false, + dataSourceMetadata.getResultIndex()); + String jobId = emrServerlessClient.startJobRun(startJobRequest); + return new DispatchQueryResponse( + queryId, jobId, false, dataSourceMetadata.getResultIndex(), null); + } } private DispatchQueryResponse handleDropIndexQuery( - DispatchQueryRequest dispatchQueryRequest, IndexDetails indexDetails) { + DispatchQueryRequest dispatchQueryRequest, IndexQueryDetails indexQueryDetails) { DataSourceMetadata dataSourceMetadata = this.dataSourceService.getRawDataSourceMetadata(dispatchQueryRequest.getDatasource()); dataSourceUserAuthorizationHelper.authorizeDataSource(dataSourceMetadata); - FlintIndexMetadata indexMetadata = flintIndexMetadataReader.getFlintIndexMetadata(indexDetails); + FlintIndexMetadata indexMetadata = + flintIndexMetadataReader.getFlintIndexMetadata(indexQueryDetails); // if index is created without auto refresh. there is no job to cancel. 
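+      // Illustrative note: the EMR Serverless job is cancelled only when the index was created
+      // with auto_refresh; the backing OpenSearch index is deleted in the finally block either
+      // way.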
String status = JobRunState.FAILED.toString(); try { @@ -216,7 +291,7 @@ private DispatchQueryResponse handleDropIndexQuery( dispatchQueryRequest.getApplicationId(), indexMetadata.getJobId()); } } finally { - String indexName = indexDetails.openSearchIndexName(); + String indexName = indexQueryDetails.openSearchIndexName(); try { AcknowledgedResponse response = client.admin().indices().delete(new DeleteIndexRequest().indices(indexName)).get(); @@ -229,7 +304,11 @@ private DispatchQueryResponse handleDropIndexQuery( } } return new DispatchQueryResponse( - new DropIndexResult(status).toJobId(), true, dataSourceMetadata.getResultIndex()); + AsyncQueryId.newAsyncQueryId(dataSourceMetadata.getName()), + new DropIndexResult(status).toJobId(), + true, + dataSourceMetadata.getResultIndex(), + null); } private static Map getDefaultTagsForJobSubmission( diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryRequest.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryRequest.java index 823a4570ce..6aa28227a1 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryRequest.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryRequest.java @@ -23,4 +23,7 @@ public class DispatchQueryRequest { /** Optional extra Spark submit parameters to include in final request */ private String extraSparkSubmitParams; + + /** Optional sessionId. */ + private String sessionId; } diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryResponse.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryResponse.java index 9ee5f156f2..e44379daff 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryResponse.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryResponse.java @@ -2,11 +2,14 @@ import lombok.AllArgsConstructor; import lombok.Data; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId; @Data @AllArgsConstructor public class DispatchQueryResponse { + private AsyncQueryId queryId; private String jobId; private boolean isDropIndexQuery; private String resultIndex; + private String sessionId; } diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/IndexDetails.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/IndexDetails.java deleted file mode 100644 index 1cc66da9fc..0000000000 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/IndexDetails.java +++ /dev/null @@ -1,60 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.sql.spark.dispatcher.model; - -import lombok.AllArgsConstructor; -import lombok.Data; -import lombok.EqualsAndHashCode; -import lombok.NoArgsConstructor; -import org.opensearch.sql.spark.flint.FlintIndexType; - -/** Index details in an async query. 
*/ -@Data -@AllArgsConstructor -@NoArgsConstructor -@EqualsAndHashCode -public class IndexDetails { - private String indexName; - private FullyQualifiedTableName fullyQualifiedTableName; - // by default, auto_refresh = false; - private Boolean autoRefresh = false; - private boolean isDropIndex; - private FlintIndexType indexType; - - public String openSearchIndexName() { - FullyQualifiedTableName fullyQualifiedTableName = getFullyQualifiedTableName(); - if (FlintIndexType.SKIPPING.equals(getIndexType())) { - String indexName = - "flint" - + "_" - + fullyQualifiedTableName.getDatasourceName() - + "_" - + fullyQualifiedTableName.getSchemaName() - + "_" - + fullyQualifiedTableName.getTableName() - + "_" - + getIndexType().getSuffix(); - return indexName.toLowerCase(); - } else if (FlintIndexType.COVERING.equals(getIndexType())) { - String indexName = - "flint" - + "_" - + fullyQualifiedTableName.getDatasourceName() - + "_" - + fullyQualifiedTableName.getSchemaName() - + "_" - + fullyQualifiedTableName.getTableName() - + "_" - + getIndexName() - + "_" - + getIndexType().getSuffix(); - return indexName.toLowerCase(); - } else { - throw new UnsupportedOperationException( - String.format("Unsupported Index Type : %s", getIndexType())); - } - } -} diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/IndexQueryActionType.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/IndexQueryActionType.java new file mode 100644 index 0000000000..2c96511d2a --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/IndexQueryActionType.java @@ -0,0 +1,15 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.dispatcher.model; + +/** Enum for Index Action in the given query.* */ +public enum IndexQueryActionType { + CREATE, + REFRESH, + DESCRIBE, + SHOW, + DROP +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/IndexQueryDetails.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/IndexQueryDetails.java new file mode 100644 index 0000000000..5b4326a10e --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/IndexQueryDetails.java @@ -0,0 +1,116 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.dispatcher.model; + +import lombok.EqualsAndHashCode; +import lombok.Getter; +import org.apache.commons.lang3.StringUtils; +import org.opensearch.sql.spark.flint.FlintIndexType; + +/** Index details in an async query. */ +@Getter +@EqualsAndHashCode +public class IndexQueryDetails { + + public static final String STRIP_CHARS = "`"; + + private String indexName; + private FullyQualifiedTableName fullyQualifiedTableName; + // by default, auto_refresh = false; + private boolean autoRefresh; + private IndexQueryActionType indexQueryActionType; + // materialized view special case where + // table name and mv name are combined. 
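+  // Illustrative example: for "CREATE MATERIALIZED VIEW mys3.default.http_count AS ...",
+  // mvName holds "mys3.default.http_count" and openSearchIndexName() below returns
+  // "flint_" + mvName with backticks stripped and lower-cased.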
+ private String mvName; + private FlintIndexType indexType; + + private IndexQueryDetails() {} + + public static IndexQueryDetailsBuilder builder() { + return new IndexQueryDetailsBuilder(); + } + + // Builder class + public static class IndexQueryDetailsBuilder { + private final IndexQueryDetails indexQueryDetails; + + public IndexQueryDetailsBuilder() { + indexQueryDetails = new IndexQueryDetails(); + } + + public IndexQueryDetailsBuilder indexName(String indexName) { + indexQueryDetails.indexName = indexName; + return this; + } + + public IndexQueryDetailsBuilder fullyQualifiedTableName(FullyQualifiedTableName tableName) { + indexQueryDetails.fullyQualifiedTableName = tableName; + return this; + } + + public IndexQueryDetailsBuilder autoRefresh(Boolean autoRefresh) { + indexQueryDetails.autoRefresh = autoRefresh; + return this; + } + + public IndexQueryDetailsBuilder indexQueryActionType( + IndexQueryActionType indexQueryActionType) { + indexQueryDetails.indexQueryActionType = indexQueryActionType; + return this; + } + + public IndexQueryDetailsBuilder mvName(String mvName) { + indexQueryDetails.mvName = mvName; + return this; + } + + public IndexQueryDetailsBuilder indexType(FlintIndexType indexType) { + indexQueryDetails.indexType = indexType; + return this; + } + + public IndexQueryDetails build() { + return indexQueryDetails; + } + } + + public String openSearchIndexName() { + FullyQualifiedTableName fullyQualifiedTableName = getFullyQualifiedTableName(); + String indexName = StringUtils.EMPTY; + switch (getIndexType()) { + case COVERING: + indexName = + "flint" + + "_" + + StringUtils.strip(fullyQualifiedTableName.getDatasourceName(), STRIP_CHARS) + + "_" + + StringUtils.strip(fullyQualifiedTableName.getSchemaName(), STRIP_CHARS) + + "_" + + StringUtils.strip(fullyQualifiedTableName.getTableName(), STRIP_CHARS) + + "_" + + StringUtils.strip(getIndexName(), STRIP_CHARS) + + "_" + + getIndexType().getSuffix(); + break; + case SKIPPING: + indexName = + "flint" + + "_" + + StringUtils.strip(fullyQualifiedTableName.getDatasourceName(), STRIP_CHARS) + + "_" + + StringUtils.strip(fullyQualifiedTableName.getSchemaName(), STRIP_CHARS) + + "_" + + StringUtils.strip(fullyQualifiedTableName.getTableName(), STRIP_CHARS) + + "_" + + getIndexType().getSuffix(); + break; + case MATERIALIZED_VIEW: + indexName = "flint" + "_" + StringUtils.strip(getMvName(), STRIP_CHARS).toLowerCase(); + break; + } + return indexName.toLowerCase(); + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/JobType.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/JobType.java new file mode 100644 index 0000000000..01f5f422e9 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/JobType.java @@ -0,0 +1,37 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.dispatcher.model; + +public enum JobType { + INTERACTIVE("interactive"), + STREAMING("streaming"), + BATCH("batch"); + + private String text; + + JobType(String text) { + this.text = text; + } + + public String getText() { + return this.text; + } + + /** + * Get JobType from text. + * + * @param text text. + * @return JobType {@link JobType}. 
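+   *     <p>Matching is case-insensitive, e.g. fromString("BATCH") and fromString("batch") both
+   *     return {@link #BATCH}; an unmatched value throws IllegalArgumentException.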
+ */ + public static JobType fromString(String text) { + for (JobType JobType : JobType.values()) { + if (JobType.text.equalsIgnoreCase(text)) { + return JobType; + } + } + throw new IllegalArgumentException("No JobType with text " + text + " found"); + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/CreateSessionRequest.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/CreateSessionRequest.java new file mode 100644 index 0000000000..b2201fbd01 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/CreateSessionRequest.java @@ -0,0 +1,60 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.execution.session; + +import java.util.Map; +import lombok.Data; +import org.opensearch.sql.spark.asyncquery.model.SparkSubmitParameters; +import org.opensearch.sql.spark.client.StartJobRequest; + +@Data +public class CreateSessionRequest { + private final String jobName; + private final String applicationId; + private final String executionRoleArn; + private final SparkSubmitParameters.Builder sparkSubmitParametersBuilder; + private final Map tags; + private final String resultIndex; + private final String datasourceName; + + public StartJobRequest getStartJobRequest() { + return new InteractiveSessionStartJobRequest( + "select 1", + jobName, + applicationId, + executionRoleArn, + sparkSubmitParametersBuilder.build().toString(), + tags, + resultIndex); + } + + static class InteractiveSessionStartJobRequest extends StartJobRequest { + public InteractiveSessionStartJobRequest( + String query, + String jobName, + String applicationId, + String executionRoleArn, + String sparkSubmitParams, + Map tags, + String resultIndex) { + super( + query, + jobName, + applicationId, + executionRoleArn, + sparkSubmitParams, + tags, + false, + resultIndex); + } + + /** Interactive query keep running. 
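+     * Returning 0 here, rather than StartJobRequest.DEFAULT_JOB_TIMEOUT, is what keeps the
+     * session's long-lived REPL job from being timed out.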
*/ + @Override + public Long executionTimeout() { + return 0L; + } + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java new file mode 100644 index 0000000000..3221b33b2c --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java @@ -0,0 +1,140 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.execution.session; + +import static org.opensearch.sql.spark.execution.session.SessionModel.initInteractiveSession; +import static org.opensearch.sql.spark.execution.session.SessionState.DEAD; +import static org.opensearch.sql.spark.execution.session.SessionState.END_STATE; +import static org.opensearch.sql.spark.execution.session.SessionState.FAIL; +import static org.opensearch.sql.spark.execution.statement.StatementId.newStatementId; +import static org.opensearch.sql.spark.execution.statestore.StateStore.createSession; +import static org.opensearch.sql.spark.execution.statestore.StateStore.getSession; + +import java.util.Optional; +import lombok.Builder; +import lombok.Getter; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.index.engine.VersionConflictEngineException; +import org.opensearch.sql.spark.client.EMRServerlessClient; +import org.opensearch.sql.spark.execution.statement.QueryRequest; +import org.opensearch.sql.spark.execution.statement.Statement; +import org.opensearch.sql.spark.execution.statement.StatementId; +import org.opensearch.sql.spark.execution.statestore.StateStore; +import org.opensearch.sql.spark.rest.model.LangType; + +/** + * Interactive session. + * + *

ENTRY_STATE: not_started
+ */
+@Getter
+@Builder
+public class InteractiveSession implements Session {
+  private static final Logger LOG = LogManager.getLogger();
+
+  public static final String SESSION_ID_TAG_KEY = "sid";
+
+  private final SessionId sessionId;
+  private final StateStore stateStore;
+  private final EMRServerlessClient serverlessClient;
+  private SessionModel sessionModel;
+
+  @Override
+  public void open(CreateSessionRequest createSessionRequest) {
+    try {
+      // append session id
+      createSessionRequest
+          .getSparkSubmitParametersBuilder()
+          .sessionExecution(sessionId.getSessionId(), createSessionRequest.getDatasourceName());
+      createSessionRequest.getTags().put(SESSION_ID_TAG_KEY, sessionId.getSessionId());
+      String jobID = serverlessClient.startJobRun(createSessionRequest.getStartJobRequest());
+      String applicationId = createSessionRequest.getStartJobRequest().getApplicationId();
+
+      sessionModel =
+          initInteractiveSession(
+              applicationId, jobID, sessionId, createSessionRequest.getDatasourceName());
+      createSession(stateStore, sessionModel.getDatasourceName()).apply(sessionModel);
+    } catch (VersionConflictEngineException e) {
+      String errorMsg = "session already exists. " + sessionId;
+      LOG.error(errorMsg);
+      throw new IllegalStateException(errorMsg);
+    }
+  }
+
+  /** TODO: StatementSweeper will delete the doc. */
+  @Override
+  public void close() {
+    Optional<SessionModel> model =
+        getSession(stateStore, sessionModel.getDatasourceName()).apply(sessionModel.getId());
+    if (model.isEmpty()) {
+      throw new IllegalStateException("session does not exist. " + sessionModel.getSessionId());
+    } else {
+      serverlessClient.cancelJobRun(sessionModel.getApplicationId(), sessionModel.getJobId());
+    }
+  }
+
+  /** Submit statement. If submitted successfully, the statement is in WAITING state. */
+  public StatementId submit(QueryRequest request) {
+    Optional<SessionModel> model =
+        getSession(stateStore, sessionModel.getDatasourceName()).apply(sessionModel.getId());
+    if (model.isEmpty()) {
+      throw new IllegalStateException("session does not exist. " + sessionModel.getSessionId());
+    } else {
+      sessionModel = model.get();
+      if (!END_STATE.contains(sessionModel.getSessionState())) {
+        String qid = request.getQueryId().getId();
+        StatementId statementId = newStatementId(qid);
+        Statement st =
+            Statement.builder()
+                .sessionId(sessionId)
+                .applicationId(sessionModel.getApplicationId())
+                .jobId(sessionModel.getJobId())
+                .stateStore(stateStore)
+                .statementId(statementId)
+                .langType(LangType.SQL)
+                .datasourceName(sessionModel.getDatasourceName())
+                .query(request.getQuery())
+                .queryId(qid)
+                .build();
+        st.open();
+        return statementId;
+      } else {
+        String errMsg =
+            String.format(
+                "can't submit statement, session should not be in end state, "
+                    + "current session state is: %s",
+                sessionModel.getSessionState().getSessionState());
+        LOG.debug(errMsg);
+        throw new IllegalStateException(errMsg);
+      }
+    }
+  }
+
+  @Override
+  public Optional<Statement> get(StatementId stID) {
+    return StateStore.getStatement(stateStore, sessionModel.getDatasourceName())
+        .apply(stID.getId())
+        .map(
+            model ->
+                Statement.builder()
+                    .sessionId(sessionId)
+                    .applicationId(model.getApplicationId())
+                    .jobId(model.getJobId())
+                    .statementId(model.getStatementId())
+                    .langType(model.getLangType())
+                    .query(model.getQuery())
+                    .queryId(model.getQueryId())
+                    .stateStore(stateStore)
+                    .statementModel(model)
+                    .build());
+  }
+
+  @Override
+  public boolean isReady() {
+    return sessionModel.getSessionState() != DEAD && sessionModel.getSessionState() != FAIL;
+  }
+}
diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/Session.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/Session.java
new file mode 100644
index 0000000000..d3d3411ded
--- /dev/null
+++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/Session.java
@@ -0,0 +1,43 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.spark.execution.session;
+
+import java.util.Optional;
+import org.opensearch.sql.spark.execution.statement.QueryRequest;
+import org.opensearch.sql.spark.execution.statement.Statement;
+import org.opensearch.sql.spark.execution.statement.StatementId;
+
+/** Session defines the statement execution context. Each session is bound to one Spark job. */
+public interface Session {
+  /** open session. */
+  void open(CreateSessionRequest createSessionRequest);
+
+  /** close session. */
+  void close();
+
+  /**
+   * submit {@link QueryRequest}.
+   *
+   * @param request {@link QueryRequest}
+   * @return {@link StatementId}
+   */
+  StatementId submit(QueryRequest request);
+
+  /**
+   * get {@link Statement}.
+   *
+   * @param stID {@link StatementId}
+   * @return {@link Statement}
+   */
+  Optional<Statement> get(StatementId stID);
+
+  SessionModel getSessionModel();
+
+  SessionId getSessionId();
+
+  /** return true if session is ready to use. */
+  boolean isReady();
+}
diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionId.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionId.java
new file mode 100644
index 0000000000..c85e4dd35c
--- /dev/null
+++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionId.java
@@ -0,0 +1,31 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.spark.execution.session;
+
+import static org.opensearch.sql.spark.utils.IDUtils.decode;
+import static org.opensearch.sql.spark.utils.IDUtils.encode;
+
+import lombok.Data;
+
+@Data
+public class SessionId {
+  public static final int PREFIX_LEN = 10;
+
+  private final String sessionId;
+
+  public static SessionId newSessionId(String datasourceName) {
+    return new SessionId(encode(datasourceName));
+  }
+
+  public String getDataSourceName() {
+    return decode(sessionId);
+  }
+
+  @Override
+  public String toString() {
+    return "sessionId=" + sessionId;
+  }
+}
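SessionId delegates to IDUtils, which is added later in this change: encode() prefixes the datasource name with a random 10-character string and Base64-encodes the result, so ids are unique while still carrying the datasource they belong to, and decode() reverses this. A sketch of the round trip, mirroring what IDUtils does (the datasource name "mys3" is illustrative):

```java
// Sketch only: the encode/decode round trip behind SessionId.
String datasourceName = "mys3";
String raw = RandomStringUtils.randomAlphanumeric(10) + datasourceName;
String encoded = Base64.getEncoder().encodeToString(raw.getBytes(StandardCharsets.UTF_8));

// getDataSourceName(): Base64-decode, then strip the 10-char random prefix.
String decoded = new String(Base64.getDecoder().decode(encoded)).substring(10);
assert decoded.equals(datasourceName);

// Same behavior through the public API:
SessionId sid = SessionId.newSessionId(datasourceName);
assert sid.getDataSourceName().equals(datasourceName);
```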

diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java
new file mode 100644
index 0000000000..81b9fdaee0
--- /dev/null
+++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java
@@ -0,0 +1,74 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.spark.execution.session;
+
+import static org.opensearch.sql.common.setting.Settings.Key.SPARK_EXECUTION_SESSION_ENABLED;
+import static org.opensearch.sql.common.setting.Settings.Key.SPARK_EXECUTION_SESSION_LIMIT;
+import static org.opensearch.sql.spark.execution.session.SessionId.newSessionId;
+import static org.opensearch.sql.spark.execution.statestore.StateStore.activeSessionsCount;
+
+import java.util.Locale;
+import java.util.Optional;
+import lombok.RequiredArgsConstructor;
+import org.opensearch.sql.common.setting.Settings;
+import org.opensearch.sql.spark.client.EMRServerlessClient;
+import org.opensearch.sql.spark.execution.statestore.StateStore;
+
+/**
+ * Singleton class.
+ *
+ * <p>TODO: add Session cache and Session sweeper.
+ */
+@RequiredArgsConstructor
+public class SessionManager {
+  private final StateStore stateStore;
+  private final EMRServerlessClient emrServerlessClient;
+  private final Settings settings;
+
+  public Session createSession(CreateSessionRequest request) {
+    int sessionMaxLimit = sessionMaxLimit();
+    if (activeSessionsCount(stateStore, request.getDatasourceName()).get() >= sessionMaxLimit) {
+      String errorMsg =
+          String.format(
+              Locale.ROOT,
+              "The maximum number of active sessions that can be supported is %d",
+              sessionMaxLimit);
+      throw new IllegalArgumentException(errorMsg);
+    }
+    InteractiveSession session =
+        InteractiveSession.builder()
+            .sessionId(newSessionId(request.getDatasourceName()))
+            .stateStore(stateStore)
+            .serverlessClient(emrServerlessClient)
+            .build();
+    session.open(request);
+    return session;
+  }
+
+  public Optional<Session> getSession(SessionId sid) {
+    Optional<SessionModel> model =
+        StateStore.getSession(stateStore, sid.getDataSourceName()).apply(sid.getSessionId());
+    if (model.isPresent()) {
+      InteractiveSession session =
+          InteractiveSession.builder()
+              .sessionId(sid)
+              .stateStore(stateStore)
+              .serverlessClient(emrServerlessClient)
+              .sessionModel(model.get())
+              .build();
+      return Optional.ofNullable(session);
+    }
+    return Optional.empty();
+  }
+
+  public boolean isEnabled() {
+    return settings.getSettingValue(SPARK_EXECUTION_SESSION_ENABLED);
+  }
+
+  public int sessionMaxLimit() {
+    return settings.getSettingValue(SPARK_EXECUTION_SESSION_LIMIT);
+  }
+}
diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionModel.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionModel.java
new file mode 100644
index 0000000000..806cdb083e
--- /dev/null
+++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionModel.java
@@ -0,0 +1,169 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.spark.execution.session;
+
+import static org.opensearch.sql.spark.execution.session.SessionState.NOT_STARTED;
+import static org.opensearch.sql.spark.execution.session.SessionType.INTERACTIVE;
+
+import java.io.IOException;
+import lombok.Builder;
+import lombok.Data;
+import lombok.SneakyThrows;
+import org.opensearch.core.xcontent.XContentBuilder;
+import org.opensearch.core.xcontent.XContentParser;
+import org.opensearch.core.xcontent.XContentParserUtils;
+import org.opensearch.index.seqno.SequenceNumbers;
+import org.opensearch.sql.spark.execution.statestore.StateModel;
+
+/** Session data in flint.ql.sessions index.
*/ +@Data +@Builder +public class SessionModel extends StateModel { + public static final String VERSION = "version"; + public static final String TYPE = "type"; + public static final String SESSION_TYPE = "sessionType"; + public static final String SESSION_ID = "sessionId"; + public static final String SESSION_STATE = "state"; + public static final String DATASOURCE_NAME = "dataSourceName"; + public static final String LAST_UPDATE_TIME = "lastUpdateTime"; + public static final String APPLICATION_ID = "applicationId"; + public static final String JOB_ID = "jobId"; + public static final String ERROR = "error"; + public static final String UNKNOWN = "unknown"; + public static final String SESSION_DOC_TYPE = "session"; + + private final String version; + private final SessionType sessionType; + private final SessionId sessionId; + private final SessionState sessionState; + private final String applicationId; + private final String jobId; + private final String datasourceName; + private final String error; + private final long lastUpdateTime; + + private final long seqNo; + private final long primaryTerm; + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder + .startObject() + .field(VERSION, version) + .field(TYPE, SESSION_DOC_TYPE) + .field(SESSION_TYPE, sessionType.getSessionType()) + .field(SESSION_ID, sessionId.getSessionId()) + .field(SESSION_STATE, sessionState.getSessionState()) + .field(DATASOURCE_NAME, datasourceName) + .field(APPLICATION_ID, applicationId) + .field(JOB_ID, jobId) + .field(LAST_UPDATE_TIME, lastUpdateTime) + .field(ERROR, error) + .endObject(); + return builder; + } + + public static SessionModel of(SessionModel copy, long seqNo, long primaryTerm) { + return builder() + .version(copy.version) + .sessionType(copy.sessionType) + .sessionId(new SessionId(copy.sessionId.getSessionId())) + .sessionState(copy.sessionState) + .datasourceName(copy.datasourceName) + .applicationId(copy.getApplicationId()) + .jobId(copy.jobId) + .error(UNKNOWN) + .lastUpdateTime(copy.getLastUpdateTime()) + .seqNo(seqNo) + .primaryTerm(primaryTerm) + .build(); + } + + public static SessionModel copyWithState( + SessionModel copy, SessionState state, long seqNo, long primaryTerm) { + return builder() + .version(copy.version) + .sessionType(copy.sessionType) + .sessionId(new SessionId(copy.sessionId.getSessionId())) + .sessionState(state) + .datasourceName(copy.datasourceName) + .applicationId(copy.getApplicationId()) + .jobId(copy.jobId) + .error(UNKNOWN) + .lastUpdateTime(copy.getLastUpdateTime()) + .seqNo(seqNo) + .primaryTerm(primaryTerm) + .build(); + } + + @SneakyThrows + public static SessionModel fromXContent(XContentParser parser, long seqNo, long primaryTerm) { + SessionModelBuilder builder = new SessionModelBuilder(); + XContentParserUtils.ensureExpectedToken( + XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (!XContentParser.Token.END_OBJECT.equals(parser.nextToken())) { + String fieldName = parser.currentName(); + parser.nextToken(); + switch (fieldName) { + case VERSION: + builder.version(parser.text()); + break; + case SESSION_TYPE: + builder.sessionType(SessionType.fromString(parser.text())); + break; + case SESSION_ID: + builder.sessionId(new SessionId(parser.text())); + break; + case SESSION_STATE: + builder.sessionState(SessionState.fromString(parser.text())); + break; + case DATASOURCE_NAME: + builder.datasourceName(parser.text()); + break; + case ERROR: + 
builder.error(parser.text());
+          break;
+        case APPLICATION_ID:
+          builder.applicationId(parser.text());
+          break;
+        case JOB_ID:
+          builder.jobId(parser.text());
+          break;
+        case LAST_UPDATE_TIME:
+          builder.lastUpdateTime(parser.longValue());
+          break;
+        case TYPE:
+          // do nothing.
+          break;
+      }
+    }
+    builder.seqNo(seqNo);
+    builder.primaryTerm(primaryTerm);
+    return builder.build();
+  }
+
+  public static SessionModel initInteractiveSession(
+      String applicationId, String jobId, SessionId sid, String datasourceName) {
+    return builder()
+        .version("1.0")
+        .sessionType(INTERACTIVE)
+        .sessionId(sid)
+        .sessionState(NOT_STARTED)
+        .datasourceName(datasourceName)
+        .applicationId(applicationId)
+        .jobId(jobId)
+        .error(UNKNOWN)
+        .lastUpdateTime(System.currentTimeMillis())
+        .seqNo(SequenceNumbers.UNASSIGNED_SEQ_NO)
+        .primaryTerm(SequenceNumbers.UNASSIGNED_PRIMARY_TERM)
+        .build();
+  }
+
+  @Override
+  public String getId() {
+    return sessionId.getSessionId();
+  }
+}
diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionState.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionState.java
new file mode 100644
index 0000000000..bd5d14c603
--- /dev/null
+++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionState.java
@@ -0,0 +1,43 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.spark.execution.session;
+
+import com.google.common.collect.ImmutableList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Locale;
+import java.util.Map;
+import java.util.stream.Collectors;
+import lombok.Getter;
+
+@Getter
+public enum SessionState {
+  NOT_STARTED("not_started"),
+  RUNNING("running"),
+  DEAD("dead"),
+  FAIL("fail");
+
+  public static List<SessionState> END_STATE = ImmutableList.of(DEAD, FAIL);
+
+  private final String sessionState;
+
+  SessionState(String sessionState) {
+    this.sessionState = sessionState;
+  }
+
+  private static Map<String, SessionState> STATES =
+      Arrays.stream(SessionState.values())
+          .collect(Collectors.toMap(t -> t.name().toLowerCase(), t -> t));
+
+  public static SessionState fromString(String key) {
+    for (SessionState ss : SessionState.values()) {
+      if (ss.getSessionState().toLowerCase(Locale.ROOT).equals(key)) {
+        return ss;
+      }
+    }
+    throw new IllegalArgumentException("Invalid session state: " + key);
+  }
+}
diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionType.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionType.java
new file mode 100644
index 0000000000..10b9ce7bd5
--- /dev/null
+++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionType.java
@@ -0,0 +1,29 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.spark.execution.session;
+
+import java.util.Locale;
+import lombok.Getter;
+
+@Getter
+public enum SessionType {
+  INTERACTIVE("interactive");
+
+  private final String sessionType;
+
+  SessionType(String sessionType) {
+    this.sessionType = sessionType;
+  }
+
+  public static SessionType fromString(String key) {
+    for (SessionType sType : SessionType.values()) {
+      if (sType.getSessionType().toLowerCase(Locale.ROOT).equals(key)) {
+        return sType;
+      }
+    }
+    throw new IllegalArgumentException("Invalid session type: " + key);
+  }
+}
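SessionModel (and StatementModel below) are persisted as JSON documents, so toXContent and fromXContent must stay in sync field by field. A sketch of a round-trip check, using the same parser setup StateStore uses; the ids and datasource name are illustrative:

```java
// Sketch only: serialize a fresh session model and parse it back.
SessionModel model =
    SessionModel.initInteractiveSession(
        "app-1", "job-1", SessionId.newSessionId("mys3"), "mys3");
XContentBuilder builder =
    model.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS);
// BytesReference is the standard OpenSearch helper for builder-to-string.
String json = BytesReference.bytes(builder).utf8ToString();

try (XContentParser parser =
    XContentType.JSON
        .xContent()
        .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, json)) {
  parser.nextToken(); // fromXContent expects the parser positioned on START_OBJECT
  SessionModel parsed =
      SessionModel.fromXContent(parser, model.getSeqNo(), model.getPrimaryTerm());
  assert parsed.getSessionState() == model.getSessionState();
  assert parsed.getSessionId().equals(model.getSessionId());
}
```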
diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/QueryRequest.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/QueryRequest.java
new file mode 100644
index 0000000000..c365265224
--- /dev/null
+++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/QueryRequest.java
@@ -0,0 +1,17 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.spark.execution.statement;
+
+import lombok.Data;
+import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId;
+import org.opensearch.sql.spark.rest.model.LangType;
+
+@Data
+public class QueryRequest {
+  private final AsyncQueryId queryId;
+  private final LangType langType;
+  private final String query;
+}
diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/Statement.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/Statement.java
new file mode 100644
index 0000000000..94c1f79511
--- /dev/null
+++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/Statement.java
@@ -0,0 +1,103 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.spark.execution.statement;
+
+import static org.opensearch.sql.spark.execution.statement.StatementModel.submitStatement;
+import static org.opensearch.sql.spark.execution.statestore.StateStore.createStatement;
+import static org.opensearch.sql.spark.execution.statestore.StateStore.getStatement;
+import static org.opensearch.sql.spark.execution.statestore.StateStore.updateStatementState;
+
+import lombok.Builder;
+import lombok.Getter;
+import lombok.Setter;
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.opensearch.index.engine.DocumentMissingException;
+import org.opensearch.index.engine.VersionConflictEngineException;
+import org.opensearch.sql.spark.execution.session.SessionId;
+import org.opensearch.sql.spark.execution.statestore.StateStore;
+import org.opensearch.sql.spark.rest.model.LangType;
+
+/** A statement represents a query to execute in a session. Each statement maps to one session. */
+@Getter
+@Builder
+public class Statement {
+  private static final Logger LOG = LogManager.getLogger();
+
+  private final SessionId sessionId;
+  private final String applicationId;
+  private final String jobId;
+  private final StatementId statementId;
+  private final LangType langType;
+  private final String datasourceName;
+  private final String query;
+  private final String queryId;
+  private final StateStore stateStore;
+
+  @Setter private StatementModel statementModel;
+
+  /** Open a statement. */
+  public void open() {
+    try {
+      statementModel =
+          submitStatement(
+              sessionId,
+              applicationId,
+              jobId,
+              statementId,
+              langType,
+              datasourceName,
+              query,
+              queryId);
+      statementModel = createStatement(stateStore, datasourceName).apply(statementModel);
+    } catch (VersionConflictEngineException e) {
+      String errorMsg = "statement already exists. " + statementId;
+      LOG.error(errorMsg);
+      throw new IllegalStateException(errorMsg);
+    }
+  }
+
+  /** Cancel a statement. */
+  public void cancel() {
+    StatementState statementState = statementModel.getStatementState();
+
+    if (statementState.equals(StatementState.SUCCESS)
+        || statementState.equals(StatementState.FAILED)
+        || statementState.equals(StatementState.CANCELLED)) {
+      String errorMsg =
+          String.format(
+              "can't cancel statement in %s state. statement: %s.",
+              statementState.getState(), statementId);
+      LOG.error(errorMsg);
+      throw new IllegalStateException(errorMsg);
+    }
+    try {
+      this.statementModel =
+          updateStatementState(stateStore, statementModel.getDatasourceName())
+              .apply(this.statementModel, StatementState.CANCELLED);
+    } catch (DocumentMissingException e) {
+      String errorMsg =
+          String.format("cancel statement failed. no statement found. statement: %s.", statementId);
+      LOG.error(errorMsg);
+      throw new IllegalStateException(errorMsg);
+    } catch (VersionConflictEngineException e) {
+      this.statementModel =
+          getStatement(stateStore, statementModel.getDatasourceName())
+              .apply(statementModel.getId())
+              .orElse(this.statementModel);
+      String errorMsg =
+          String.format(
+              "cancel statement failed. current statementState: %s, statement: %s.",
+              this.statementModel.getStatementState(), statementId);
+      LOG.error(errorMsg);
+      throw new IllegalStateException(errorMsg);
+    }
+  }
+
+  public StatementState getStatementState() {
+    return statementModel.getStatementState();
+  }
+}
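Statement.open() and Statement.cancel() implement a small state machine: a statement starts in WAITING, and cancel() is rejected once it reaches a terminal state (SUCCESS, FAILED, or CANCELLED). A sketch of that lifecycle; `sessionId` and `stateStore` are assumed to be in scope and the other values are illustrative:

```java
// Sketch only: the lifecycle enforced by Statement.cancel() above.
Statement st =
    Statement.builder()
        .sessionId(sessionId)      // assumed in scope
        .applicationId("app-1")
        .jobId("job-1")
        .statementId(StatementId.newStatementId("qid-1"))
        .langType(LangType.SQL)
        .datasourceName("mys3")
        .query("SELECT 1")
        .queryId("qid-1")
        .stateStore(stateStore)    // assumed in scope
        .build();

st.open();    // persists the doc in WAITING state
st.cancel();  // WAITING -> CANCELLED is allowed

// A second cancel() would now throw IllegalStateException: CANCELLED is terminal.
```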
diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementId.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementId.java
new file mode 100644
index 0000000000..33284c4b3d
--- /dev/null
+++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementId.java
@@ -0,0 +1,23 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.spark.execution.statement;
+
+import lombok.Data;
+
+@Data
+public class StatementId {
+  private final String id;
+
+  // construct statementId from queryId.
+  public static StatementId newStatementId(String qid) {
+    return new StatementId(qid);
+  }
+
+  @Override
+  public String toString() {
+    return "statementId=" + id;
+  }
+}
diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementModel.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementModel.java
new file mode 100644
index 0000000000..2a1043bf73
--- /dev/null
+++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementModel.java
@@ -0,0 +1,204 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.spark.execution.statement;
+
+import static org.opensearch.sql.spark.execution.session.SessionModel.APPLICATION_ID;
+import static org.opensearch.sql.spark.execution.session.SessionModel.DATASOURCE_NAME;
+import static org.opensearch.sql.spark.execution.session.SessionModel.JOB_ID;
+import static org.opensearch.sql.spark.execution.statement.StatementState.WAITING;
+
+import java.io.IOException;
+import lombok.Builder;
+import lombok.Data;
+import lombok.SneakyThrows;
+import org.opensearch.core.xcontent.XContentBuilder;
+import org.opensearch.core.xcontent.XContentParser;
+import org.opensearch.core.xcontent.XContentParserUtils;
+import org.opensearch.index.seqno.SequenceNumbers;
+import org.opensearch.sql.spark.execution.session.SessionId;
+import org.opensearch.sql.spark.execution.statestore.StateModel;
+import org.opensearch.sql.spark.rest.model.LangType;
+
+/** Statement data in flint.ql.sessions index.
*/ +@Data +@Builder +public class StatementModel extends StateModel { + public static final String VERSION = "version"; + public static final String TYPE = "type"; + public static final String STATEMENT_STATE = "state"; + public static final String STATEMENT_ID = "statementId"; + public static final String SESSION_ID = "sessionId"; + public static final String LANG = "lang"; + public static final String QUERY = "query"; + public static final String QUERY_ID = "queryId"; + public static final String SUBMIT_TIME = "submitTime"; + public static final String ERROR = "error"; + public static final String UNKNOWN = "unknown"; + public static final String STATEMENT_DOC_TYPE = "statement"; + + private final String version; + private final StatementState statementState; + private final StatementId statementId; + private final SessionId sessionId; + private final String applicationId; + private final String jobId; + private final LangType langType; + private final String datasourceName; + private final String query; + private final String queryId; + private final long submitTime; + private final String error; + + private final long seqNo; + private final long primaryTerm; + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder + .startObject() + .field(VERSION, version) + .field(TYPE, STATEMENT_DOC_TYPE) + .field(STATEMENT_STATE, statementState.getState()) + .field(STATEMENT_ID, statementId.getId()) + .field(SESSION_ID, sessionId.getSessionId()) + .field(APPLICATION_ID, applicationId) + .field(JOB_ID, jobId) + .field(LANG, langType.getText()) + .field(DATASOURCE_NAME, datasourceName) + .field(QUERY, query) + .field(QUERY_ID, queryId) + .field(SUBMIT_TIME, submitTime) + .field(ERROR, error) + .endObject(); + return builder; + } + + public static StatementModel copy(StatementModel copy, long seqNo, long primaryTerm) { + return builder() + .version("1.0") + .statementState(copy.statementState) + .statementId(copy.statementId) + .sessionId(copy.sessionId) + .applicationId(copy.applicationId) + .jobId(copy.jobId) + .langType(copy.langType) + .datasourceName(copy.datasourceName) + .query(copy.query) + .queryId(copy.queryId) + .submitTime(copy.submitTime) + .error(copy.error) + .seqNo(seqNo) + .primaryTerm(primaryTerm) + .build(); + } + + public static StatementModel copyWithState( + StatementModel copy, StatementState state, long seqNo, long primaryTerm) { + return builder() + .version("1.0") + .statementState(state) + .statementId(copy.statementId) + .sessionId(copy.sessionId) + .applicationId(copy.applicationId) + .jobId(copy.jobId) + .langType(copy.langType) + .datasourceName(copy.datasourceName) + .query(copy.query) + .queryId(copy.queryId) + .submitTime(copy.submitTime) + .error(copy.error) + .seqNo(seqNo) + .primaryTerm(primaryTerm) + .build(); + } + + @SneakyThrows + public static StatementModel fromXContent(XContentParser parser, long seqNo, long primaryTerm) { + StatementModel.StatementModelBuilder builder = StatementModel.builder(); + XContentParserUtils.ensureExpectedToken( + XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (!XContentParser.Token.END_OBJECT.equals(parser.nextToken())) { + String fieldName = parser.currentName(); + parser.nextToken(); + switch (fieldName) { + case VERSION: + builder.version(parser.text()); + break; + case TYPE: + // do nothing + break; + case STATEMENT_STATE: + builder.statementState(StatementState.fromString(parser.text())); + break; + case STATEMENT_ID: + 
builder.statementId(new StatementId(parser.text())); + break; + case SESSION_ID: + builder.sessionId(new SessionId(parser.text())); + break; + case APPLICATION_ID: + builder.applicationId(parser.text()); + break; + case JOB_ID: + builder.jobId(parser.text()); + break; + case LANG: + builder.langType(LangType.fromString(parser.text())); + break; + case DATASOURCE_NAME: + builder.datasourceName(parser.text()); + break; + case QUERY: + builder.query(parser.text()); + break; + case QUERY_ID: + builder.queryId(parser.text()); + break; + case SUBMIT_TIME: + builder.submitTime(parser.longValue()); + break; + case ERROR: + builder.error(parser.text()); + break; + } + } + builder.seqNo(seqNo); + builder.primaryTerm(primaryTerm); + return builder.build(); + } + + public static StatementModel submitStatement( + SessionId sid, + String applicationId, + String jobId, + StatementId statementId, + LangType langType, + String datasourceName, + String query, + String queryId) { + return builder() + .version("1.0") + .statementState(WAITING) + .statementId(statementId) + .sessionId(sid) + .applicationId(applicationId) + .jobId(jobId) + .langType(langType) + .datasourceName(datasourceName) + .query(query) + .queryId(queryId) + .submitTime(System.currentTimeMillis()) + .error(UNKNOWN) + .seqNo(SequenceNumbers.UNASSIGNED_SEQ_NO) + .primaryTerm(SequenceNumbers.UNASSIGNED_PRIMARY_TERM) + .build(); + } + + @Override + public String getId() { + return statementId.getId(); + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementState.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementState.java new file mode 100644 index 0000000000..48978ff8f9 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementState.java @@ -0,0 +1,41 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.execution.statement; + +import java.util.Arrays; +import java.util.Locale; +import java.util.Map; +import java.util.stream.Collectors; +import lombok.Getter; + +/** {@link Statement} State. 
*/
+@Getter
+public enum StatementState {
+  WAITING("waiting"),
+  RUNNING("running"),
+  SUCCESS("success"),
+  FAILED("failed"),
+  CANCELLED("cancelled");
+
+  private final String state;
+
+  StatementState(String state) {
+    this.state = state;
+  }
+
+  private static Map<String, StatementState> STATES =
+      Arrays.stream(StatementState.values())
+          .collect(Collectors.toMap(t -> t.name().toLowerCase(), t -> t));
+
+  public static StatementState fromString(String key) {
+    for (StatementState ss : StatementState.values()) {
+      if (ss.getState().toLowerCase(Locale.ROOT).equals(key)) {
+        return ss;
+      }
+    }
+    throw new IllegalArgumentException("Invalid statement state: " + key);
+  }
+}
diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateModel.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateModel.java
new file mode 100644
index 0000000000..b5bf31a6ba
--- /dev/null
+++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateModel.java
@@ -0,0 +1,30 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.spark.execution.statestore;
+
+import org.opensearch.core.xcontent.ToXContentObject;
+import org.opensearch.core.xcontent.XContentParser;
+
+public abstract class StateModel implements ToXContentObject {
+
+  public abstract String getId();
+
+  public abstract long getSeqNo();
+
+  public abstract long getPrimaryTerm();
+
+  public interface CopyBuilder<T extends StateModel> {
+    T of(T copy, long seqNo, long primaryTerm);
+  }
+
+  public interface StateCopyBuilder<T extends StateModel, S> {
+    T of(T copy, S state, long seqNo, long primaryTerm);
+  }
+
+  public interface FromXContent<T extends StateModel> {
+    T fromXContent(XContentParser parser, long seqNo, long primaryTerm);
+  }
+}
diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateStore.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateStore.java
new file mode 100644
index 0000000000..e6bad9fc26
--- /dev/null
+++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateStore.java
@@ -0,0 +1,307 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.spark.execution.statestore;
+
+import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_REQUEST_BUFFER_INDEX_NAME;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.charset.StandardCharsets;
+import java.util.Locale;
+import java.util.Optional;
+import java.util.function.BiFunction;
+import java.util.function.Function;
+import java.util.function.Supplier;
+import lombok.RequiredArgsConstructor;
+import org.apache.commons.io.IOUtils;
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.opensearch.action.DocWriteResponse;
+import org.opensearch.action.admin.indices.create.CreateIndexRequest;
+import org.opensearch.action.admin.indices.create.CreateIndexResponse;
+import org.opensearch.action.get.GetRequest;
+import org.opensearch.action.get.GetResponse;
+import org.opensearch.action.index.IndexRequest;
+import org.opensearch.action.index.IndexResponse;
+import org.opensearch.action.search.SearchRequest;
+import org.opensearch.action.search.SearchResponse;
+import org.opensearch.action.support.WriteRequest;
+import org.opensearch.action.update.UpdateRequest;
+import org.opensearch.action.update.UpdateResponse;
+import org.opensearch.client.Client;
+import org.opensearch.cluster.service.ClusterService;
+import org.opensearch.common.action.ActionFuture;
+import org.opensearch.common.util.concurrent.ThreadContext;
+import org.opensearch.common.xcontent.LoggingDeprecationHandler;
+import org.opensearch.common.xcontent.XContentFactory;
+import org.opensearch.common.xcontent.XContentType;
+import org.opensearch.core.xcontent.NamedXContentRegistry;
+import org.opensearch.core.xcontent.ToXContent;
+import org.opensearch.core.xcontent.XContentParser;
+import org.opensearch.index.query.QueryBuilder;
+import org.opensearch.index.query.QueryBuilders;
+import org.opensearch.search.builder.SearchSourceBuilder;
+import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata;
+import org.opensearch.sql.spark.execution.session.SessionModel;
+import org.opensearch.sql.spark.execution.session.SessionState;
+import org.opensearch.sql.spark.execution.session.SessionType;
+import org.opensearch.sql.spark.execution.statement.StatementModel;
+import org.opensearch.sql.spark.execution.statement.StatementState;
+
+/**
+ * The state store maintains the state of sessions and statements. It creates, updates, and gets
+ * docs on the index regardless of the user's FGAC permissions.
+ */
+@RequiredArgsConstructor
+public class StateStore {
+  public static String SETTINGS_FILE_NAME = "query_execution_request_settings.yml";
+  public static String MAPPING_FILE_NAME = "query_execution_request_mapping.yml";
+  public static Function<String, String> DATASOURCE_TO_REQUEST_INDEX =
+      datasourceName -> String.format("%s_%s", SPARK_REQUEST_BUFFER_INDEX_NAME, datasourceName);
+
+  private static final Logger LOG = LogManager.getLogger();
+
+  private final Client client;
+  private final ClusterService clusterService;
+
+  protected <T extends StateModel> T create(
+      T st, StateModel.CopyBuilder<T> builder, String indexName) {
+    try {
+      if (!this.clusterService.state().routingTable().hasIndex(indexName)) {
+        createIndex(indexName);
+      }
+      IndexRequest indexRequest =
+          new IndexRequest(indexName)
+              .id(st.getId())
+              .source(st.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS))
+              .setIfSeqNo(st.getSeqNo())
+              .setIfPrimaryTerm(st.getPrimaryTerm())
+              .create(true)
+              .setRefreshPolicy(WriteRequest.RefreshPolicy.WAIT_UNTIL);
+      try (ThreadContext.StoredContext ignored =
+          client.threadPool().getThreadContext().stashContext()) {
+        IndexResponse indexResponse = client.index(indexRequest).actionGet();
+        if (indexResponse.getResult().equals(DocWriteResponse.Result.CREATED)) {
+          LOG.debug("Successfully created doc. id: {}", st.getId());
+          return builder.of(st, indexResponse.getSeqNo(), indexResponse.getPrimaryTerm());
+        } else {
+          throw new RuntimeException(
+              String.format(
+                  Locale.ROOT,
+                  "Failed to create doc. id: %s, error: %s",
+                  st.getId(),
+                  indexResponse.getResult().getLowercase()));
+        }
+      }
+    } catch (IOException e) {
+      throw new RuntimeException(e);
+    }
+  }
+
+  protected <T extends StateModel> Optional<T> get(
+      String sid, StateModel.FromXContent<T> builder, String indexName) {
+    try {
+      if (!this.clusterService.state().routingTable().hasIndex(indexName)) {
+        createIndex(indexName);
+        return Optional.empty();
+      }
+      GetRequest getRequest = new GetRequest().index(indexName).id(sid).refresh(true);
+      try (ThreadContext.StoredContext ignored =
+          client.threadPool().getThreadContext().stashContext()) {
+        GetResponse getResponse = client.get(getRequest).actionGet();
+        if (getResponse.isExists()) {
+          XContentParser parser =
+              XContentType.JSON
+                  .xContent()
+                  .createParser(
+                      NamedXContentRegistry.EMPTY,
+                      LoggingDeprecationHandler.INSTANCE,
+                      getResponse.getSourceAsString());
+          parser.nextToken();
+          return Optional.of(
+              builder.fromXContent(parser, getResponse.getSeqNo(), getResponse.getPrimaryTerm()));
+        } else {
+          return Optional.empty();
+        }
+      }
+    } catch (IOException e) {
+      throw new RuntimeException(e);
+    }
+  }
+
+  protected <T extends StateModel, S> T updateState(
+      T st, S state, StateModel.StateCopyBuilder<T, S> builder, String indexName) {
+    try {
+      T model = builder.of(st, state, st.getSeqNo(), st.getPrimaryTerm());
+      UpdateRequest updateRequest =
+          new UpdateRequest()
+              .index(indexName)
+              .id(model.getId())
+              .setIfSeqNo(model.getSeqNo())
+              .setIfPrimaryTerm(model.getPrimaryTerm())
+              .doc(model.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS))
+              .fetchSource(true)
+              .setRefreshPolicy(WriteRequest.RefreshPolicy.WAIT_UNTIL);
+      try (ThreadContext.StoredContext ignored =
+          client.threadPool().getThreadContext().stashContext()) {
+        UpdateResponse updateResponse = client.update(updateRequest).actionGet();
+        if (updateResponse.getResult().equals(DocWriteResponse.Result.UPDATED)) {
+          LOG.debug("Successfully updated doc. id: {}", st.getId());
+          return builder.of(
+              model, state, updateResponse.getSeqNo(), updateResponse.getPrimaryTerm());
+        } else {
+          throw new RuntimeException(
+              String.format(
+                  Locale.ROOT,
+                  "Failed to update doc. id: %s, error: %s",
+                  st.getId(),
+                  updateResponse.getResult().getLowercase()));
+        }
+      }
+    } catch (IOException e) {
+      throw new RuntimeException(e);
+    }
+  }
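Every read above returns the document's seqNo and primaryTerm, and every write sends them back via setIfSeqNo/setIfPrimaryTerm, so a stale writer fails with VersionConflictEngineException instead of silently overwriting newer state. A sketch of the read-modify-write pattern this enables for callers, using the static helper functions defined just below; the datasource name and doc id are illustrative:

```java
// Sketch only: optimistic concurrency with StateStore's helper functions.
Function<String, Optional<StatementModel>> read = getStatement(stateStore, "mys3");
BiFunction<StatementModel, StatementState, StatementModel> update =
    updateStatementState(stateStore, "mys3");

StatementModel current = read.apply("statement-doc-id").orElseThrow();
try {
  // the update request carries current.getSeqNo()/getPrimaryTerm()
  current = update.apply(current, StatementState.RUNNING);
} catch (VersionConflictEngineException e) {
  // another writer won; re-read to pick up the new seqNo/primaryTerm, then decide
  current = read.apply("statement-doc-id").orElseThrow();
}
```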
+
+  private void createIndex(String indexName) {
+    try {
+      CreateIndexRequest createIndexRequest = new CreateIndexRequest(indexName);
+      createIndexRequest
+          .mapping(loadConfigFromResource(MAPPING_FILE_NAME), XContentType.YAML)
+          .settings(loadConfigFromResource(SETTINGS_FILE_NAME), XContentType.YAML);
+      ActionFuture<CreateIndexResponse> createIndexResponseActionFuture;
+      try (ThreadContext.StoredContext ignored =
+          client.threadPool().getThreadContext().stashContext()) {
+        createIndexResponseActionFuture = client.admin().indices().create(createIndexRequest);
+      }
+      CreateIndexResponse createIndexResponse = createIndexResponseActionFuture.actionGet();
+      if (createIndexResponse.isAcknowledged()) {
+        LOG.info("Index: {} creation Acknowledged", indexName);
+      } else {
+        throw new RuntimeException("Index creation is not acknowledged.");
+      }
+    } catch (Throwable e) {
+      throw new RuntimeException(
+          "Internal server error while creating " + indexName + " index: " + e.getMessage());
+    }
+  }
+
+  private long count(String indexName, QueryBuilder query) {
+    if (!this.clusterService.state().routingTable().hasIndex(indexName)) {
+      return 0;
+    }
+    SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
+    searchSourceBuilder.query(query);
+    searchSourceBuilder.size(0);
+
+    // https://github.com/opensearch-project/sql/issues/1801.
+    SearchRequest searchRequest =
+        new SearchRequest()
+            .indices(indexName)
+            .preference("_primary_first")
+            .source(searchSourceBuilder);
+
+    ActionFuture<SearchResponse> searchResponseActionFuture;
+    try (ThreadContext.StoredContext ignored =
+        client.threadPool().getThreadContext().stashContext()) {
+      searchResponseActionFuture = client.search(searchRequest);
+    }
+    SearchResponse searchResponse = searchResponseActionFuture.actionGet();
+    if (searchResponse.status().getStatus() != 200) {
+      throw new RuntimeException(
+          "Fetching job metadata information failed with status: " + searchResponse.status());
+    } else {
+      return searchResponse.getHits().getTotalHits().value;
+    }
+  }
+
+  private String loadConfigFromResource(String fileName) throws IOException {
+    InputStream fileStream = StateStore.class.getClassLoader().getResourceAsStream(fileName);
+    return IOUtils.toString(fileStream, StandardCharsets.UTF_8);
+  }
+
+  /** Helper Functions */
+  public static Function<StatementModel, StatementModel> createStatement(
+      StateStore stateStore, String datasourceName) {
+    return (st) ->
+        stateStore.create(
+            st, StatementModel::copy, DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName));
+  }
+
+  public static Function<String, Optional<StatementModel>> getStatement(
+      StateStore stateStore, String datasourceName) {
+    return (docId) ->
+        stateStore.get(
+            docId, StatementModel::fromXContent, DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName));
+  }
+
+  public static BiFunction<StatementModel, StatementState, StatementModel> updateStatementState(
+      StateStore stateStore, String datasourceName) {
+    return (old, state) ->
+        stateStore.updateState(
+            old,
+            state,
+            StatementModel::copyWithState,
+            DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName));
+  }
+
+  public static Function<SessionModel, SessionModel> createSession(
+      StateStore stateStore, String datasourceName) {
+    return (session) ->
+        stateStore.create(
+            session, SessionModel::of, DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName));
+  }
+
+  public static Function<String, Optional<SessionModel>> getSession(
+      StateStore stateStore, String datasourceName) {
+    return (docId) ->
+        stateStore.get(
+            docId, SessionModel::fromXContent,
DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName));
+  }
+
+  public static BiFunction<SessionModel, SessionState, SessionModel> updateSessionState(
+      StateStore stateStore, String datasourceName) {
+    return (old, state) ->
+        stateStore.updateState(
+            old,
+            state,
+            SessionModel::copyWithState,
+            DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName));
+  }
+
+  public static Function<AsyncQueryJobMetadata, AsyncQueryJobMetadata> createJobMetaData(
+      StateStore stateStore, String datasourceName) {
+    return (jobMetadata) ->
+        stateStore.create(
+            jobMetadata,
+            AsyncQueryJobMetadata::copy,
+            DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName));
+  }
+
+  public static Function<String, Optional<AsyncQueryJobMetadata>> getJobMetaData(
+      StateStore stateStore, String datasourceName) {
+    return (docId) ->
+        stateStore.get(
+            docId,
+            AsyncQueryJobMetadata::fromXContent,
+            DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName));
+  }
+
+  public static Supplier<Long> activeSessionsCount(StateStore stateStore, String datasourceName) {
+    return () ->
+        stateStore.count(
+            DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName),
+            QueryBuilders.boolQuery()
+                .must(QueryBuilders.termQuery(SessionModel.TYPE, SessionModel.SESSION_DOC_TYPE))
+                .must(
+                    QueryBuilders.termQuery(
+                        SessionModel.SESSION_TYPE, SessionType.INTERACTIVE.getSessionType()))
+                .must(QueryBuilders.termQuery(SessionModel.DATASOURCE_NAME, datasourceName))
+                .must(
+                    QueryBuilders.termQuery(
+                        SessionModel.SESSION_STATE, SessionState.RUNNING.getSessionState())));
+  }
+}
diff --git a/spark/src/main/java/org/opensearch/sql/spark/flint/FlintIndexMetadataReader.java b/spark/src/main/java/org/opensearch/sql/spark/flint/FlintIndexMetadataReader.java
index e4a5e92035..d4a8e7ddbf 100644
--- a/spark/src/main/java/org/opensearch/sql/spark/flint/FlintIndexMetadataReader.java
+++ b/spark/src/main/java/org/opensearch/sql/spark/flint/FlintIndexMetadataReader.java
@@ -1,6 +1,6 @@
 package org.opensearch.sql.spark.flint;
 
-import org.opensearch.sql.spark.dispatcher.model.IndexDetails;
+import org.opensearch.sql.spark.dispatcher.model.IndexQueryDetails;
 
 /** Interface for FlintIndexMetadataReader */
 public interface FlintIndexMetadataReader {
@@ -8,8 +8,8 @@
 
   /**
    * Given Index details, get the streaming job Id.
    *
-   * @param indexDetails indexDetails.
+   * @param indexQueryDetails index query details.
    * @return FlintIndexMetadata.
*/ - FlintIndexMetadata getFlintIndexMetadata(IndexDetails indexDetails); + FlintIndexMetadata getFlintIndexMetadata(IndexQueryDetails indexQueryDetails); } diff --git a/spark/src/main/java/org/opensearch/sql/spark/flint/FlintIndexMetadataReaderImpl.java b/spark/src/main/java/org/opensearch/sql/spark/flint/FlintIndexMetadataReaderImpl.java index 5f712e65cd..a16d0b9138 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/flint/FlintIndexMetadataReaderImpl.java +++ b/spark/src/main/java/org/opensearch/sql/spark/flint/FlintIndexMetadataReaderImpl.java @@ -5,7 +5,7 @@ import org.opensearch.action.admin.indices.mapping.get.GetMappingsResponse; import org.opensearch.client.Client; import org.opensearch.cluster.metadata.MappingMetadata; -import org.opensearch.sql.spark.dispatcher.model.IndexDetails; +import org.opensearch.sql.spark.dispatcher.model.IndexQueryDetails; /** Implementation of {@link FlintIndexMetadataReader} */ @AllArgsConstructor @@ -14,8 +14,8 @@ public class FlintIndexMetadataReaderImpl implements FlintIndexMetadataReader { private final Client client; @Override - public FlintIndexMetadata getFlintIndexMetadata(IndexDetails indexDetails) { - String indexName = indexDetails.openSearchIndexName(); + public FlintIndexMetadata getFlintIndexMetadata(IndexQueryDetails indexQueryDetails) { + String indexName = indexQueryDetails.openSearchIndexName(); GetMappingsResponse mappingsResponse = client.admin().indices().prepareGetMappings(indexName).get(); try { diff --git a/spark/src/main/java/org/opensearch/sql/spark/response/JobExecutionResponseReader.java b/spark/src/main/java/org/opensearch/sql/spark/response/JobExecutionResponseReader.java index d3cbd68dce..2614992463 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/response/JobExecutionResponseReader.java +++ b/spark/src/main/java/org/opensearch/sql/spark/response/JobExecutionResponseReader.java @@ -39,6 +39,10 @@ public JSONObject getResultFromOpensearchIndex(String jobId, String resultIndex) return searchInSparkIndex(QueryBuilders.termQuery(JOB_ID_FIELD, jobId), resultIndex); } + public JSONObject getResultWithQueryId(String queryId, String resultIndex) { + return searchInSparkIndex(QueryBuilders.termQuery("queryId", queryId), resultIndex); + } + private JSONObject searchInSparkIndex(QueryBuilder query, String resultIndex) { SearchRequest searchRequest = new SearchRequest(); String searchResultIndex = resultIndex == null ? 
SPARK_RESPONSE_BUFFER_INDEX_NAME : resultIndex; diff --git a/spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequest.java b/spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequest.java index 8802630d9f..6acf6bc9a8 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequest.java +++ b/spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequest.java @@ -6,6 +6,7 @@ package org.opensearch.sql.spark.rest.model; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.sql.spark.execution.session.SessionModel.SESSION_ID; import java.io.IOException; import lombok.Data; @@ -18,6 +19,8 @@ public class CreateAsyncQueryRequest { private String query; private String datasource; private LangType lang; + // optional sessionId + private String sessionId; public CreateAsyncQueryRequest(String query, String datasource, LangType lang) { this.query = Validate.notNull(query, "Query can't be null"); @@ -25,11 +28,19 @@ public CreateAsyncQueryRequest(String query, String datasource, LangType lang) { this.lang = Validate.notNull(lang, "lang can't be null"); } + public CreateAsyncQueryRequest(String query, String datasource, LangType lang, String sessionId) { + this.query = Validate.notNull(query, "Query can't be null"); + this.datasource = Validate.notNull(datasource, "Datasource can't be null"); + this.lang = Validate.notNull(lang, "lang can't be null"); + this.sessionId = sessionId; + } + public static CreateAsyncQueryRequest fromXContentParser(XContentParser parser) throws IOException { String query = null; LangType lang = null; String datasource = null; + String sessionId = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { String fieldName = parser.currentName(); @@ -41,10 +52,12 @@ public static CreateAsyncQueryRequest fromXContentParser(XContentParser parser) lang = LangType.fromString(langString); } else if (fieldName.equals("datasource")) { datasource = parser.textOrNull(); + } else if (fieldName.equals(SESSION_ID)) { + sessionId = parser.textOrNull(); } else { throw new IllegalArgumentException("Unknown field: " + fieldName); } } - return new CreateAsyncQueryRequest(query, datasource, lang); + return new CreateAsyncQueryRequest(query, datasource, lang, sessionId); } } diff --git a/spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryResponse.java b/spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryResponse.java index 8cfe57c2a6..2f918308c4 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryResponse.java +++ b/spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryResponse.java @@ -12,4 +12,6 @@ @AllArgsConstructor public class CreateAsyncQueryResponse { private String queryId; + // optional sessionId + private String sessionId; } diff --git a/spark/src/main/java/org/opensearch/sql/spark/utils/IDUtils.java b/spark/src/main/java/org/opensearch/sql/spark/utils/IDUtils.java new file mode 100644 index 0000000000..438d2342b4 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/utils/IDUtils.java @@ -0,0 +1,25 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.utils; + +import java.nio.charset.StandardCharsets; +import java.util.Base64; +import lombok.experimental.UtilityClass; 
+import org.apache.commons.lang3.RandomStringUtils; + +@UtilityClass +public class IDUtils { + public static final int PREFIX_LEN = 10; + + public static String decode(String id) { + return new String(Base64.getDecoder().decode(id)).substring(PREFIX_LEN); + } + + public static String encode(String datasourceName) { + String randomId = RandomStringUtils.randomAlphanumeric(PREFIX_LEN) + datasourceName; + return Base64.getEncoder().encodeToString(randomId.getBytes(StandardCharsets.UTF_8)); + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/utils/SQLQueryUtils.java b/spark/src/main/java/org/opensearch/sql/spark/utils/SQLQueryUtils.java index f6b75d49ef..c1f3f02576 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/utils/SQLQueryUtils.java +++ b/spark/src/main/java/org/opensearch/sql/spark/utils/SQLQueryUtils.java @@ -20,7 +20,8 @@ import org.opensearch.sql.spark.antlr.parser.SqlBaseParser; import org.opensearch.sql.spark.antlr.parser.SqlBaseParserBaseVisitor; import org.opensearch.sql.spark.dispatcher.model.FullyQualifiedTableName; -import org.opensearch.sql.spark.dispatcher.model.IndexDetails; +import org.opensearch.sql.spark.dispatcher.model.IndexQueryActionType; +import org.opensearch.sql.spark.dispatcher.model.IndexQueryDetails; import org.opensearch.sql.spark.flint.FlintIndexType; /** @@ -42,7 +43,7 @@ public static FullyQualifiedTableName extractFullyQualifiedTableName(String sqlQ return sparkSqlTableNameVisitor.getFullyQualifiedTableName(); } - public static IndexDetails extractIndexDetails(String sqlQuery) { + public static IndexQueryDetails extractIndexDetails(String sqlQuery) { FlintSparkSqlExtensionsParser flintSparkSqlExtensionsParser = new FlintSparkSqlExtensionsParser( new CommonTokenStream( @@ -52,10 +53,10 @@ public static IndexDetails extractIndexDetails(String sqlQuery) { flintSparkSqlExtensionsParser.statement(); FlintSQLIndexDetailsVisitor flintSQLIndexDetailsVisitor = new FlintSQLIndexDetailsVisitor(); statementContext.accept(flintSQLIndexDetailsVisitor); - return flintSQLIndexDetailsVisitor.getIndexDetails(); + return flintSQLIndexDetailsVisitor.getIndexQueryDetailsBuilder().build(); } - public static boolean isIndexQuery(String sqlQuery) { + public static boolean isFlintExtensionQuery(String sqlQuery) { FlintSparkSqlExtensionsParser flintSparkSqlExtensionsParser = new FlintSparkSqlExtensionsParser( new CommonTokenStream( @@ -117,29 +118,29 @@ public Void visitCreateTableHeader(SqlBaseParser.CreateTableHeaderContext ctx) { public static class FlintSQLIndexDetailsVisitor extends FlintSparkSqlExtensionsBaseVisitor { - @Getter private final IndexDetails indexDetails; + @Getter private final IndexQueryDetails.IndexQueryDetailsBuilder indexQueryDetailsBuilder; public FlintSQLIndexDetailsVisitor() { - this.indexDetails = new IndexDetails(); + this.indexQueryDetailsBuilder = new IndexQueryDetails.IndexQueryDetailsBuilder(); } @Override public Void visitIndexName(FlintSparkSqlExtensionsParser.IndexNameContext ctx) { - indexDetails.setIndexName(ctx.getText()); + indexQueryDetailsBuilder.indexName(ctx.getText()); return super.visitIndexName(ctx); } @Override public Void visitTableName(FlintSparkSqlExtensionsParser.TableNameContext ctx) { - indexDetails.setFullyQualifiedTableName(new FullyQualifiedTableName(ctx.getText())); + indexQueryDetailsBuilder.fullyQualifiedTableName(new FullyQualifiedTableName(ctx.getText())); return super.visitTableName(ctx); } @Override public Void visitCreateSkippingIndexStatement( 
FlintSparkSqlExtensionsParser.CreateSkippingIndexStatementContext ctx) { - indexDetails.setDropIndex(false); - indexDetails.setIndexType(FlintIndexType.SKIPPING); + indexQueryDetailsBuilder.indexQueryActionType(IndexQueryActionType.CREATE); + indexQueryDetailsBuilder.indexType(FlintIndexType.SKIPPING); visitPropertyList(ctx.propertyList()); return super.visitCreateSkippingIndexStatement(ctx); } @@ -147,28 +148,113 @@ public Void visitCreateSkippingIndexStatement( @Override public Void visitCreateCoveringIndexStatement( FlintSparkSqlExtensionsParser.CreateCoveringIndexStatementContext ctx) { - indexDetails.setDropIndex(false); - indexDetails.setIndexType(FlintIndexType.COVERING); + indexQueryDetailsBuilder.indexQueryActionType(IndexQueryActionType.CREATE); + indexQueryDetailsBuilder.indexType(FlintIndexType.COVERING); visitPropertyList(ctx.propertyList()); return super.visitCreateCoveringIndexStatement(ctx); } + @Override + public Void visitCreateMaterializedViewStatement( + FlintSparkSqlExtensionsParser.CreateMaterializedViewStatementContext ctx) { + indexQueryDetailsBuilder.indexQueryActionType(IndexQueryActionType.CREATE); + indexQueryDetailsBuilder.indexType(FlintIndexType.MATERIALIZED_VIEW); + indexQueryDetailsBuilder.mvName(ctx.mvName.getText()); + visitPropertyList(ctx.propertyList()); + return super.visitCreateMaterializedViewStatement(ctx); + } + @Override public Void visitDropCoveringIndexStatement( FlintSparkSqlExtensionsParser.DropCoveringIndexStatementContext ctx) { - indexDetails.setDropIndex(true); - indexDetails.setIndexType(FlintIndexType.COVERING); + indexQueryDetailsBuilder.indexQueryActionType(IndexQueryActionType.DROP); + indexQueryDetailsBuilder.indexType(FlintIndexType.COVERING); return super.visitDropCoveringIndexStatement(ctx); } @Override public Void visitDropSkippingIndexStatement( FlintSparkSqlExtensionsParser.DropSkippingIndexStatementContext ctx) { - indexDetails.setDropIndex(true); - indexDetails.setIndexType(FlintIndexType.SKIPPING); + indexQueryDetailsBuilder.indexQueryActionType(IndexQueryActionType.DROP); + indexQueryDetailsBuilder.indexType(FlintIndexType.SKIPPING); return super.visitDropSkippingIndexStatement(ctx); } + @Override + public Void visitDropMaterializedViewStatement( + FlintSparkSqlExtensionsParser.DropMaterializedViewStatementContext ctx) { + indexQueryDetailsBuilder.indexQueryActionType(IndexQueryActionType.DROP); + indexQueryDetailsBuilder.indexType(FlintIndexType.MATERIALIZED_VIEW); + indexQueryDetailsBuilder.mvName(ctx.mvName.getText()); + return super.visitDropMaterializedViewStatement(ctx); + } + + @Override + public Void visitDescribeCoveringIndexStatement( + FlintSparkSqlExtensionsParser.DescribeCoveringIndexStatementContext ctx) { + indexQueryDetailsBuilder.indexQueryActionType(IndexQueryActionType.DESCRIBE); + indexQueryDetailsBuilder.indexType(FlintIndexType.COVERING); + return super.visitDescribeCoveringIndexStatement(ctx); + } + + @Override + public Void visitDescribeSkippingIndexStatement( + FlintSparkSqlExtensionsParser.DescribeSkippingIndexStatementContext ctx) { + indexQueryDetailsBuilder.indexQueryActionType(IndexQueryActionType.DESCRIBE); + indexQueryDetailsBuilder.indexType(FlintIndexType.SKIPPING); + return super.visitDescribeSkippingIndexStatement(ctx); + } + + @Override + public Void visitDescribeMaterializedViewStatement( + FlintSparkSqlExtensionsParser.DescribeMaterializedViewStatementContext ctx) { + indexQueryDetailsBuilder.indexQueryActionType(IndexQueryActionType.DESCRIBE); + 
indexQueryDetailsBuilder.indexType(FlintIndexType.MATERIALIZED_VIEW); + indexQueryDetailsBuilder.mvName(ctx.mvName.getText()); + return super.visitDescribeMaterializedViewStatement(ctx); + } + + @Override + public Void visitShowCoveringIndexStatement( + FlintSparkSqlExtensionsParser.ShowCoveringIndexStatementContext ctx) { + indexQueryDetailsBuilder.indexQueryActionType(IndexQueryActionType.SHOW); + indexQueryDetailsBuilder.indexType(FlintIndexType.COVERING); + return super.visitShowCoveringIndexStatement(ctx); + } + + @Override + public Void visitShowMaterializedViewStatement( + FlintSparkSqlExtensionsParser.ShowMaterializedViewStatementContext ctx) { + indexQueryDetailsBuilder.indexQueryActionType(IndexQueryActionType.SHOW); + indexQueryDetailsBuilder.indexType(FlintIndexType.MATERIALIZED_VIEW); + return super.visitShowMaterializedViewStatement(ctx); + } + + @Override + public Void visitRefreshCoveringIndexStatement( + FlintSparkSqlExtensionsParser.RefreshCoveringIndexStatementContext ctx) { + indexQueryDetailsBuilder.indexQueryActionType(IndexQueryActionType.REFRESH); + indexQueryDetailsBuilder.indexType(FlintIndexType.COVERING); + return super.visitRefreshCoveringIndexStatement(ctx); + } + + @Override + public Void visitRefreshSkippingIndexStatement( + FlintSparkSqlExtensionsParser.RefreshSkippingIndexStatementContext ctx) { + indexQueryDetailsBuilder.indexQueryActionType(IndexQueryActionType.REFRESH); + indexQueryDetailsBuilder.indexType(FlintIndexType.SKIPPING); + return super.visitRefreshSkippingIndexStatement(ctx); + } + + @Override + public Void visitRefreshMaterializedViewStatement( + FlintSparkSqlExtensionsParser.RefreshMaterializedViewStatementContext ctx) { + indexQueryDetailsBuilder.indexQueryActionType(IndexQueryActionType.REFRESH); + indexQueryDetailsBuilder.indexType(FlintIndexType.MATERIALIZED_VIEW); + indexQueryDetailsBuilder.mvName(ctx.mvName.getText()); + return super.visitRefreshMaterializedViewStatement(ctx); + } + @Override public Void visitPropertyList(FlintSparkSqlExtensionsParser.PropertyListContext ctx) { if (ctx != null) { @@ -180,7 +266,7 @@ public Void visitPropertyList(FlintSparkSqlExtensionsParser.PropertyListContext // https://github.com/apache/spark/blob/v3.5.0/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkParserUtils.scala#L35 to unescape string literal if (propertyKey(property.key).toLowerCase(Locale.ROOT).contains("auto_refresh")) { if (propertyValue(property.value).toLowerCase(Locale.ROOT).contains("true")) { - indexDetails.setAutoRefresh(true); + indexQueryDetailsBuilder.autoRefresh(true); } } }); diff --git a/spark/src/main/resources/job-metadata-index-mapping.yml b/spark/src/main/resources/job-metadata-index-mapping.yml deleted file mode 100644 index 3a39b989a2..0000000000 --- a/spark/src/main/resources/job-metadata-index-mapping.yml +++ /dev/null @@ -1,25 +0,0 @@ ---- -## -# Copyright OpenSearch Contributors -# SPDX-License-Identifier: Apache-2.0 -## - -# Schema file for the .ql-job-metadata index -# Also "dynamic" is set to "false" so that other fields can be added. 
-dynamic: false
-properties:
-  jobId:
-    type: text
-    fields:
-      keyword:
-        type: keyword
-  applicationId:
-    type: text
-    fields:
-      keyword:
-        type: keyword
-  resultIndex:
-    type: text
-    fields:
-      keyword:
-        type: keyword
\ No newline at end of file
diff --git a/spark/src/main/resources/query_execution_request_mapping.yml b/spark/src/main/resources/query_execution_request_mapping.yml
new file mode 100644
index 0000000000..682534d338
--- /dev/null
+++ b/spark/src/main/resources/query_execution_request_mapping.yml
@@ -0,0 +1,44 @@
+---
+##
+# Copyright OpenSearch Contributors
+# SPDX-License-Identifier: Apache-2.0
+##
+
+# Schema file for the query execution request index
+# "dynamic" is set to "false" so that documents may carry extra fields without mapping updates.
+dynamic: false
+properties:
+  version:
+    type: keyword
+  type:
+    type: keyword
+  state:
+    type: keyword
+  statementId:
+    type: keyword
+  applicationId:
+    type: keyword
+  sessionId:
+    type: keyword
+  sessionType:
+    type: keyword
+  error:
+    type: text
+  lang:
+    type: keyword
+  query:
+    type: text
+  dataSourceName:
+    type: keyword
+  submitTime:
+    type: date
+    format: strict_date_time||epoch_millis
+  jobId:
+    type: keyword
+  lastUpdateTime:
+    type: date
+    format: strict_date_time||epoch_millis
+  queryId:
+    type: keyword
+  excludeJobIds:
+    type: keyword
diff --git a/spark/src/main/resources/job-metadata-index-settings.yml b/spark/src/main/resources/query_execution_request_settings.yml
similarity index 88%
rename from spark/src/main/resources/job-metadata-index-settings.yml
rename to spark/src/main/resources/query_execution_request_settings.yml
index be93f4645c..da2bf07bf1 100644
--- a/spark/src/main/resources/job-metadata-index-settings.yml
+++ b/spark/src/main/resources/query_execution_request_settings.yml
@@ -8,4 +8,4 @@
 index:
   number_of_shards: "1"
   auto_expand_replicas: "0-2"
-  number_of_replicas: "0"
\ No newline at end of file
+  number_of_replicas: "0"
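For context on how these two resources are consumed: a StateStore-style index bootstrap can read the YAML mapping and settings from the classpath and hand them to a `CreateIndexRequest`. This is a minimal sketch under stated assumptions (`loadYamlResource` is a hypothetical helper; the actual creation code lives in `StateStore` and may differ):

```java
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import org.opensearch.action.admin.indices.create.CreateIndexRequest;
import org.opensearch.common.xcontent.XContentType;

// Build the request index from the two YAML resources above.
CreateIndexRequest request =
    new CreateIndexRequest(indexName)
        .mapping(loadYamlResource("query_execution_request_mapping.yml"), XContentType.YAML)
        .settings(loadYamlResource("query_execution_request_settings.yml"), XContentType.YAML);
if (!client.admin().indices().create(request).actionGet().isAcknowledged()) {
  throw new IllegalStateException("Index creation not acknowledged: " + indexName);
}

// Hypothetical classpath reader used above.
static String loadYamlResource(String name) throws IOException {
  try (InputStream in = StateStore.class.getClassLoader().getResourceAsStream(name)) {
    return new String(in.readAllBytes(), StandardCharsets.UTF_8);
  }
}
```

The `auto_expand_replicas: "0-2"` setting keeps the request index writable on a single-node cluster while still replicating on larger ones.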
diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java
new file mode 100644
index 0000000000..6bc40c009b
--- /dev/null
+++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java
@@ -0,0 +1,602 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.spark.asyncquery;
+
+import static org.opensearch.sql.opensearch.setting.OpenSearchSettings.SPARK_EXECUTION_SESSION_ENABLED_SETTING;
+import static org.opensearch.sql.opensearch.setting.OpenSearchSettings.SPARK_EXECUTION_SESSION_LIMIT_SETTING;
+import static org.opensearch.sql.spark.data.constants.SparkConstants.DEFAULT_CLASS_NAME;
+import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_JOB_REQUEST_INDEX;
+import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_JOB_SESSION_ID;
+import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_SESSION_CLASS_NAME;
+import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_REQUEST_BUFFER_INDEX_NAME;
+import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_RESPONSE_BUFFER_INDEX_NAME;
+import static org.opensearch.sql.spark.execution.session.SessionModel.SESSION_DOC_TYPE;
+import static org.opensearch.sql.spark.execution.statement.StatementModel.SESSION_ID;
+import static org.opensearch.sql.spark.execution.statement.StatementModel.STATEMENT_DOC_TYPE;
+import static org.opensearch.sql.spark.execution.statestore.StateStore.DATASOURCE_TO_REQUEST_INDEX;
+import static org.opensearch.sql.spark.execution.statestore.StateStore.getSession;
+import static org.opensearch.sql.spark.execution.statestore.StateStore.getStatement;
+import static org.opensearch.sql.spark.execution.statestore.StateStore.updateSessionState;
+import static org.opensearch.sql.spark.execution.statestore.StateStore.updateStatementState;
+
+import com.amazonaws.services.emrserverless.model.CancelJobRunResult;
+import com.amazonaws.services.emrserverless.model.GetJobRunResult;
+import com.amazonaws.services.emrserverless.model.JobRun;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ImmutableSet;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import lombok.Getter;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.opensearch.action.search.SearchRequest;
+import org.opensearch.action.search.SearchResponse;
+import org.opensearch.client.node.NodeClient;
+import org.opensearch.cluster.service.ClusterService;
+import org.opensearch.common.settings.ClusterSettings;
+import org.opensearch.common.settings.Setting;
+import org.opensearch.common.settings.Settings;
+import org.opensearch.index.query.QueryBuilder;
+import org.opensearch.index.query.QueryBuilders;
+import org.opensearch.plugins.Plugin;
+import org.opensearch.search.builder.SearchSourceBuilder;
+import org.opensearch.sql.datasource.model.DataSourceMetadata;
+import org.opensearch.sql.datasource.model.DataSourceType;
+import org.opensearch.sql.datasources.auth.DataSourceUserAuthorizationHelperImpl;
+import org.opensearch.sql.datasources.encryptor.EncryptorImpl;
+import org.opensearch.sql.datasources.glue.GlueDataSourceFactory;
+import org.opensearch.sql.datasources.service.DataSourceMetadataStorage;
+import org.opensearch.sql.datasources.service.DataSourceServiceImpl;
+import org.opensearch.sql.datasources.storage.OpenSearchDataSourceMetadataStorage;
+import org.opensearch.sql.opensearch.setting.OpenSearchSettings;
+import org.opensearch.sql.spark.asyncquery.model.AsyncQueryExecutionResponse;
+import org.opensearch.sql.spark.client.EMRServerlessClient;
+import org.opensearch.sql.spark.client.StartJobRequest;
+import org.opensearch.sql.spark.config.SparkExecutionEngineConfig;
+import org.opensearch.sql.spark.dispatcher.SparkQueryDispatcher;
+import org.opensearch.sql.spark.execution.session.SessionId;
+import org.opensearch.sql.spark.execution.session.SessionManager;
+import org.opensearch.sql.spark.execution.session.SessionModel;
+import org.opensearch.sql.spark.execution.session.SessionState;
+import org.opensearch.sql.spark.execution.statement.StatementModel;
+import org.opensearch.sql.spark.execution.statement.StatementState;
+import org.opensearch.sql.spark.execution.statestore.StateStore;
+import org.opensearch.sql.spark.flint.FlintIndexMetadataReaderImpl;
+import org.opensearch.sql.spark.response.JobExecutionResponseReader;
+import org.opensearch.sql.spark.rest.model.CreateAsyncQueryRequest;
+import org.opensearch.sql.spark.rest.model.CreateAsyncQueryResponse;
+import org.opensearch.sql.spark.rest.model.LangType;
+import org.opensearch.sql.storage.DataSourceFactory;
+import org.opensearch.test.OpenSearchIntegTestCase;
+
+public class AsyncQueryExecutorServiceImplSpecTest extends OpenSearchIntegTestCase {
+
+  public static final String DATASOURCE = "mys3";
+
+  private ClusterService clusterService;
+  private org.opensearch.sql.common.setting.Settings pluginSettings;
+  private NodeClient client;
+  private DataSourceServiceImpl dataSourceService;
+  private StateStore stateStore;
+  private ClusterSettings clusterSettings;
+
+  @Override
+  protected Collection<Class<? extends Plugin>> nodePlugins() {
+    return Arrays.asList(TestSettingPlugin.class);
+  }
+
+  public static class TestSettingPlugin extends Plugin {
+    @Override
+    public List<Setting<?>> getSettings() {
+      return OpenSearchSettings.pluginSettings();
+    }
+  }
+
+  @Before
+  public void setup() {
+    clusterService = clusterService();
+    clusterSettings = clusterService.getClusterSettings();
+    pluginSettings = new OpenSearchSettings(clusterSettings);
+    client = (NodeClient) cluster().client();
+    dataSourceService = createDataSourceService();
+    dataSourceService.createDataSource(
+        new DataSourceMetadata(
+            DATASOURCE,
+            DataSourceType.S3GLUE,
+            ImmutableList.of(),
+            ImmutableMap.of(
+                "glue.auth.type",
+                "iam_role",
+                "glue.auth.role_arn",
+                "arn:aws:iam::924196221507:role/FlintOpensearchServiceRole",
+                "glue.indexstore.opensearch.uri",
+                "http://localhost:9200",
+                "glue.indexstore.opensearch.auth",
+                "noauth"),
+            null));
+    stateStore = new StateStore(client, clusterService);
+    createIndex(SPARK_RESPONSE_BUFFER_INDEX_NAME);
+  }
+
+  @After
+  public void clean() {
+    client
+        .admin()
+        .cluster()
+        .prepareUpdateSettings()
+        .setTransientSettings(
+            Settings.builder().putNull(SPARK_EXECUTION_SESSION_ENABLED_SETTING.getKey()).build())
+        .get();
+    client
+        .admin()
+        .cluster()
+        .prepareUpdateSettings()
+        .setTransientSettings(
+            Settings.builder().putNull(SPARK_EXECUTION_SESSION_LIMIT_SETTING.getKey()).build())
+        .get();
+  }
+
+  @Test
+  public void withoutSessionCreateAsyncQueryThenGetResultThenCancel() {
+    LocalEMRSClient emrsClient = new LocalEMRSClient();
+    AsyncQueryExecutorService asyncQueryExecutorService =
+        createAsyncQueryExecutorService(emrsClient);
+
+    // disable session
+    enableSession(false);
+
+    // 1. create async query.
+    CreateAsyncQueryResponse response =
+        asyncQueryExecutorService.createAsyncQuery(
+            new CreateAsyncQueryRequest("select 1", DATASOURCE, LangType.SQL, null));
+    assertFalse(clusterService().state().routingTable().hasIndex(SPARK_REQUEST_BUFFER_INDEX_NAME));
+    emrsClient.startJobRunCalled(1);
+
+    // 2. fetch async query result.
+    AsyncQueryExecutionResponse asyncQueryResults =
+        asyncQueryExecutorService.getAsyncQueryResults(response.getQueryId());
+    assertEquals("RUNNING", asyncQueryResults.getStatus());
+    emrsClient.getJobRunResultCalled(1);
+
+    // 3. cancel async query.
+    String cancelQueryId = asyncQueryExecutorService.cancelQuery(response.getQueryId());
+    assertEquals(response.getQueryId(), cancelQueryId);
+    emrsClient.cancelJobRunCalled(1);
+  }
+
+  @Test
+  public void createAsyncQueryCreateJobWithCorrectParameters() {
+    LocalEMRSClient emrsClient = new LocalEMRSClient();
+    AsyncQueryExecutorService asyncQueryExecutorService =
+        createAsyncQueryExecutorService(emrsClient);
+
+    enableSession(false);
+    CreateAsyncQueryResponse response =
+        asyncQueryExecutorService.createAsyncQuery(
+            new CreateAsyncQueryRequest("select 1", DATASOURCE, LangType.SQL, null));
+    String params = emrsClient.getJobRequest().getSparkSubmitParams();
+    assertNull(response.getSessionId());
+    assertTrue(params.contains(String.format("--class %s", DEFAULT_CLASS_NAME)));
+    assertFalse(
+        params.contains(
+            String.format("%s=%s", FLINT_JOB_REQUEST_INDEX, SPARK_REQUEST_BUFFER_INDEX_NAME)));
+    assertFalse(
+        params.contains(String.format("%s=%s", FLINT_JOB_SESSION_ID, response.getSessionId())));
+
+    // enable session
+    enableSession(true);
+    response =
+        asyncQueryExecutorService.createAsyncQuery(
+            new CreateAsyncQueryRequest("select 1", DATASOURCE, LangType.SQL, null));
+    params = emrsClient.getJobRequest().getSparkSubmitParams();
+    assertTrue(params.contains(String.format("--class %s", FLINT_SESSION_CLASS_NAME)));
+    assertTrue(
+        params.contains(
+            String.format("%s=%s", FLINT_JOB_REQUEST_INDEX, SPARK_REQUEST_BUFFER_INDEX_NAME)));
+    assertTrue(
+        params.contains(String.format("%s=%s", FLINT_JOB_SESSION_ID, response.getSessionId())));
+  }
+
+  @Test
+  public void withSessionCreateAsyncQueryThenGetResultThenCancel() {
+    LocalEMRSClient emrsClient = new LocalEMRSClient();
+    AsyncQueryExecutorService asyncQueryExecutorService =
+        createAsyncQueryExecutorService(emrsClient);
+
+    // enable session
+    enableSession(true);
+
+    // 1. create async query.
+    CreateAsyncQueryResponse response =
+        asyncQueryExecutorService.createAsyncQuery(
+            new CreateAsyncQueryRequest("select 1", DATASOURCE, LangType.SQL, null));
+    assertNotNull(response.getSessionId());
+    Optional<StatementModel> statementModel =
+        getStatement(stateStore, DATASOURCE).apply(response.getQueryId());
+    assertTrue(statementModel.isPresent());
+    assertEquals(StatementState.WAITING, statementModel.get().getStatementState());
+
+    // 2. fetch async query result.
+    AsyncQueryExecutionResponse asyncQueryResults =
+        asyncQueryExecutorService.getAsyncQueryResults(response.getQueryId());
+    assertEquals(StatementState.WAITING.getState(), asyncQueryResults.getStatus());
+
+    // 3. cancel async query.
+    String cancelQueryId = asyncQueryExecutorService.cancelQuery(response.getQueryId());
+    assertEquals(response.getQueryId(), cancelQueryId);
+  }
+
+  @Test
+  public void reuseSessionWhenCreateAsyncQuery() {
+    LocalEMRSClient emrsClient = new LocalEMRSClient();
+    AsyncQueryExecutorService asyncQueryExecutorService =
+        createAsyncQueryExecutorService(emrsClient);
+
+    // enable session
+    enableSession(true);
+
+    // 1. create async query.
+    CreateAsyncQueryResponse first =
+        asyncQueryExecutorService.createAsyncQuery(
+            new CreateAsyncQueryRequest("select 1", DATASOURCE, LangType.SQL, null));
+    assertNotNull(first.getSessionId());
+
+    // 2. reuse session id
+    CreateAsyncQueryResponse second =
+        asyncQueryExecutorService.createAsyncQuery(
+            new CreateAsyncQueryRequest(
+                "select 1", DATASOURCE, LangType.SQL, first.getSessionId()));
+
+    assertEquals(first.getSessionId(), second.getSessionId());
+    assertNotEquals(first.getQueryId(), second.getQueryId());
+    // one session doc.
+    assertEquals(
+        1,
+        search(
+            QueryBuilders.boolQuery()
+                .must(QueryBuilders.termQuery("type", SESSION_DOC_TYPE))
+                .must(QueryBuilders.termQuery(SESSION_ID, first.getSessionId()))));
+    // two statement docs have the same sessionId.
+    assertEquals(
+        2,
+        search(
+            QueryBuilders.boolQuery()
+                .must(QueryBuilders.termQuery("type", STATEMENT_DOC_TYPE))
+                .must(QueryBuilders.termQuery(SESSION_ID, first.getSessionId()))));
+
+    Optional<StatementModel> firstModel =
+        getStatement(stateStore, DATASOURCE).apply(first.getQueryId());
+    assertTrue(firstModel.isPresent());
+    assertEquals(StatementState.WAITING, firstModel.get().getStatementState());
+    assertEquals(first.getQueryId(), firstModel.get().getStatementId().getId());
+    assertEquals(first.getQueryId(), firstModel.get().getQueryId());
+    Optional<StatementModel> secondModel =
+        getStatement(stateStore, DATASOURCE).apply(second.getQueryId());
+    assertEquals(StatementState.WAITING, secondModel.get().getStatementState());
+    assertEquals(second.getQueryId(), secondModel.get().getStatementId().getId());
+    assertEquals(second.getQueryId(), secondModel.get().getQueryId());
+  }
+
+  @Test
+  public void batchQueryHasTimeout() {
+    LocalEMRSClient emrsClient = new LocalEMRSClient();
+    AsyncQueryExecutorService asyncQueryExecutorService =
+        createAsyncQueryExecutorService(emrsClient);
+
+    enableSession(false);
+    CreateAsyncQueryResponse response =
+        asyncQueryExecutorService.createAsyncQuery(
+            new CreateAsyncQueryRequest("select 1", DATASOURCE, LangType.SQL, null));
+
+    assertEquals(120L, (long) emrsClient.getJobRequest().executionTimeout());
+  }
+
+  @Test
+  public void interactiveQueryNoTimeout() {
+    LocalEMRSClient emrsClient = new LocalEMRSClient();
+    AsyncQueryExecutorService asyncQueryExecutorService =
+        createAsyncQueryExecutorService(emrsClient);
+
+    // enable session
+    enableSession(true);
+
+    asyncQueryExecutorService.createAsyncQuery(
+        new CreateAsyncQueryRequest("select 1", DATASOURCE, LangType.SQL, null));
+    assertEquals(0L, (long) emrsClient.getJobRequest().executionTimeout());
+  }
+
+  @Test
+  public void datasourceWithBasicAuth() {
+    Map<String, String> properties = new HashMap<>();
+    properties.put("glue.auth.type", "iam_role");
+    properties.put(
+        "glue.auth.role_arn", "arn:aws:iam::924196221507:role/FlintOpensearchServiceRole");
+    properties.put("glue.indexstore.opensearch.uri", "http://localhost:9200");
+    properties.put("glue.indexstore.opensearch.auth", "basicauth");
+    properties.put("glue.indexstore.opensearch.auth.username", "username");
+    properties.put("glue.indexstore.opensearch.auth.password", "password");
+
+    dataSourceService.createDataSource(
+        new DataSourceMetadata(
+            "mybasicauth", DataSourceType.S3GLUE, ImmutableList.of(), properties, null));
+    LocalEMRSClient emrsClient = new LocalEMRSClient();
+    AsyncQueryExecutorService asyncQueryExecutorService =
+        createAsyncQueryExecutorService(emrsClient);
+
+    // enable session
+    enableSession(true);
+
+    asyncQueryExecutorService.createAsyncQuery(
+        new CreateAsyncQueryRequest("select 1", "mybasicauth", LangType.SQL, null));
+    String params = emrsClient.getJobRequest().getSparkSubmitParams();
+    assertTrue(params.contains(String.format("--conf spark.datasource.flint.auth=basic")));
+    assertTrue(
+        params.contains(String.format("--conf spark.datasource.flint.auth.username=username")));
+    assertTrue(
+        params.contains(String.format("--conf spark.datasource.flint.auth.password=password")));
+  }
+
+  @Test
+  public void withSessionCreateAsyncQueryFailed() {
+    LocalEMRSClient emrsClient = new LocalEMRSClient();
+    AsyncQueryExecutorService asyncQueryExecutorService =
+        createAsyncQueryExecutorService(emrsClient);
+
+    // enable session
+    enableSession(true);
+
+    // 1. create async query.
+    CreateAsyncQueryResponse response =
+        asyncQueryExecutorService.createAsyncQuery(
+            new CreateAsyncQueryRequest("myselect 1", DATASOURCE, LangType.SQL, null));
+    assertNotNull(response.getSessionId());
+    Optional<StatementModel> statementModel =
+        getStatement(stateStore, DATASOURCE).apply(response.getQueryId());
+    assertTrue(statementModel.isPresent());
+    assertEquals(StatementState.WAITING, statementModel.get().getStatementState());
+
+    // 2. fetch async query result. No result has been written to
+    // SPARK_RESPONSE_BUFFER_INDEX_NAME yet, so mock a failed statement.
+    StatementModel submitted = statementModel.get();
+    StatementModel mocked =
+        StatementModel.builder()
+            .version("1.0")
+            .statementState(submitted.getStatementState())
+            .statementId(submitted.getStatementId())
+            .sessionId(submitted.getSessionId())
+            .applicationId(submitted.getApplicationId())
+            .jobId(submitted.getJobId())
+            .langType(submitted.getLangType())
+            .datasourceName(submitted.getDatasourceName())
+            .query(submitted.getQuery())
+            .queryId(submitted.getQueryId())
+            .submitTime(submitted.getSubmitTime())
+            .error("mock error")
+            .seqNo(submitted.getSeqNo())
+            .primaryTerm(submitted.getPrimaryTerm())
+            .build();
+    updateStatementState(stateStore, DATASOURCE).apply(mocked, StatementState.FAILED);
+
+    AsyncQueryExecutionResponse asyncQueryResults =
+        asyncQueryExecutorService.getAsyncQueryResults(response.getQueryId());
+    assertEquals(StatementState.FAILED.getState(), asyncQueryResults.getStatus());
+    assertEquals("mock error", asyncQueryResults.getError());
+  }
+
+  // https://github.com/opensearch-project/sql/issues/2344
+  @Test
+  public void createSessionMoreThanLimitFailed() {
+    LocalEMRSClient emrsClient = new LocalEMRSClient();
+    AsyncQueryExecutorService asyncQueryExecutorService =
+        createAsyncQueryExecutorService(emrsClient);
+
+    // enable session
+    enableSession(true);
+    // only allow one session in the domain.
+    setSessionLimit(1);
+
+    // 1. create async query.
+    CreateAsyncQueryResponse first =
+        asyncQueryExecutorService.createAsyncQuery(
+            new CreateAsyncQueryRequest("select 1", DATASOURCE, LangType.SQL, null));
+    assertNotNull(first.getSessionId());
+    setSessionState(first.getSessionId(), SessionState.RUNNING);
+
+    // 2. creating another query requires a second session, which exceeds the limit.
+    IllegalArgumentException exception =
+        assertThrows(
+            IllegalArgumentException.class,
+            () ->
+                asyncQueryExecutorService.createAsyncQuery(
+                    new CreateAsyncQueryRequest("select 1", DATASOURCE, LangType.SQL, null)));
+    assertEquals(
+        "The maximum number of active sessions can be supported is 1", exception.getMessage());
+  }
+
+  // https://github.com/opensearch-project/sql/issues/2360
+  @Test
+  public void recreateSessionIfNotReady() {
+    LocalEMRSClient emrsClient = new LocalEMRSClient();
+    AsyncQueryExecutorService asyncQueryExecutorService =
+        createAsyncQueryExecutorService(emrsClient);
+
+    // enable session
+    enableSession(true);
+
+    // 1. create async query.
+    CreateAsyncQueryResponse first =
+        asyncQueryExecutorService.createAsyncQuery(
+            new CreateAsyncQueryRequest("select 1", DATASOURCE, LangType.SQL, null));
+    assertNotNull(first.getSessionId());
+
+    // set sessionState to FAIL
+    setSessionState(first.getSessionId(), SessionState.FAIL);
+
+    // 2. reuse session id
+    CreateAsyncQueryResponse second =
+        asyncQueryExecutorService.createAsyncQuery(
+            new CreateAsyncQueryRequest(
+                "select 1", DATASOURCE, LangType.SQL, first.getSessionId()));
+
+    assertNotEquals(first.getSessionId(), second.getSessionId());
+
+    // set sessionState to DEAD
+    setSessionState(second.getSessionId(), SessionState.DEAD);
+
+    // 3. reuse session id
+    CreateAsyncQueryResponse third =
+        asyncQueryExecutorService.createAsyncQuery(
+            new CreateAsyncQueryRequest(
+                "select 1", DATASOURCE, LangType.SQL, second.getSessionId()));
+    assertNotEquals(second.getSessionId(), third.getSessionId());
+  }
+
+  @Test
+  public void submitQueryInInvalidSessionThrowException() {
+    LocalEMRSClient emrsClient = new LocalEMRSClient();
+    AsyncQueryExecutorService asyncQueryExecutorService =
+        createAsyncQueryExecutorService(emrsClient);
+
+    // enable session
+    enableSession(true);
+
+    // 1. create async query with a session id that does not exist.
+    SessionId sessionId = SessionId.newSessionId(DATASOURCE);
+    IllegalArgumentException exception =
+        assertThrows(
+            IllegalArgumentException.class,
+            () ->
+                asyncQueryExecutorService.createAsyncQuery(
+                    new CreateAsyncQueryRequest(
+                        "select 1", DATASOURCE, LangType.SQL, sessionId.getSessionId())));
+    assertEquals("no session found. " + sessionId, exception.getMessage());
+  }
+
+  private DataSourceServiceImpl createDataSourceService() {
+    String masterKey = "a57d991d9b573f75b9bba1df";
+    DataSourceMetadataStorage dataSourceMetadataStorage =
+        new OpenSearchDataSourceMetadataStorage(
+            client, clusterService, new EncryptorImpl(masterKey));
+    return new DataSourceServiceImpl(
+        new ImmutableSet.Builder<DataSourceFactory>()
+            .add(new GlueDataSourceFactory(pluginSettings))
+            .build(),
+        dataSourceMetadataStorage,
+        meta -> {});
+  }
+
+  private AsyncQueryExecutorService createAsyncQueryExecutorService(
+      EMRServerlessClient emrServerlessClient) {
+    StateStore stateStore = new StateStore(client, clusterService);
+    AsyncQueryJobMetadataStorageService asyncQueryJobMetadataStorageService =
+        new OpensearchAsyncQueryJobMetadataStorageService(stateStore);
+    JobExecutionResponseReader jobExecutionResponseReader = new JobExecutionResponseReader(client);
+    SparkQueryDispatcher sparkQueryDispatcher =
+        new SparkQueryDispatcher(
+            emrServerlessClient,
+            this.dataSourceService,
+            new DataSourceUserAuthorizationHelperImpl(client),
+            jobExecutionResponseReader,
+            new FlintIndexMetadataReaderImpl(client),
+            client,
+            new SessionManager(stateStore, emrServerlessClient, pluginSettings));
+    return new AsyncQueryExecutorServiceImpl(
+        asyncQueryJobMetadataStorageService,
+        sparkQueryDispatcher,
+        this::sparkExecutionEngineConfig);
+  }
+
+  public static class LocalEMRSClient implements EMRServerlessClient {
+
+    private int startJobRunCalled = 0;
+    private int cancelJobRunCalled = 0;
+    private int getJobResult = 0;
+
+    @Getter private StartJobRequest jobRequest;
+
+    @Override
+    public String startJobRun(StartJobRequest startJobRequest) {
+      jobRequest = startJobRequest;
+      startJobRunCalled++;
+      return "jobId";
+    }
+
+    @Override
+    public GetJobRunResult getJobRunResult(String applicationId, String jobId) {
+      getJobResult++;
+      JobRun jobRun = new JobRun();
+      jobRun.setState("RUNNING");
+      return new GetJobRunResult().withJobRun(jobRun);
+    }
+
+    @Override
+    public CancelJobRunResult cancelJobRun(String applicationId, String jobId) {
+      cancelJobRunCalled++;
+      return new CancelJobRunResult().withJobRunId(jobId);
+    }
+
+    public void startJobRunCalled(int expectedTimes) {
+      assertEquals(expectedTimes, startJobRunCalled);
+    }
+
+    public void cancelJobRunCalled(int expectedTimes) {
+      assertEquals(expectedTimes, cancelJobRunCalled);
+    }
+
+    public void getJobRunResultCalled(int expectedTimes) {
+      assertEquals(expectedTimes, getJobResult);
+    }
+  }
+
+  public SparkExecutionEngineConfig sparkExecutionEngineConfig() {
+    return new SparkExecutionEngineConfig("appId", "us-west-2", "roleArn", "", "myCluster");
+  }
+
+  public void enableSession(boolean enabled) {
+    client
+        .admin()
+        .cluster()
+        .prepareUpdateSettings()
+        .setTransientSettings(
+            Settings.builder()
+                .put(SPARK_EXECUTION_SESSION_ENABLED_SETTING.getKey(), enabled)
+                .build())
+        .get();
+  }
+
+  public void setSessionLimit(long limit) {
+    client
+        .admin()
+        .cluster()
+        .prepareUpdateSettings()
+        .setTransientSettings(
+            Settings.builder().put(SPARK_EXECUTION_SESSION_LIMIT_SETTING.getKey(), limit).build())
+        .get();
+  }
+
+  int search(QueryBuilder query) {
+    SearchRequest searchRequest = new SearchRequest();
+    searchRequest.indices(DATASOURCE_TO_REQUEST_INDEX.apply(DATASOURCE));
+    SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
+    searchSourceBuilder.query(query);
+    searchRequest.source(searchSourceBuilder);
+    SearchResponse searchResponse = client.search(searchRequest).actionGet();
+
+    return searchResponse.getHits().getHits().length;
+  }
+
+  void setSessionState(String sessionId, SessionState sessionState) {
+    Optional<SessionModel> model = getSession(stateStore, DATASOURCE).apply(sessionId);
+    SessionModel updated =
+        updateSessionState(stateStore, DATASOURCE).apply(model.get(), sessionState);
+    assertEquals(sessionState, updated.getSessionState());
+  }
+}
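Taken together, the spec tests above all walk the same lifecycle. A condensed sketch of that flow, using only calls that appear in the tests (service wiring as in `createAsyncQueryExecutorService`; `"mys3"` is the test datasource):

```java
// create -> poll -> cancel, as exercised by the spec tests above.
CreateAsyncQueryResponse created =
    asyncQueryExecutorService.createAsyncQuery(
        new CreateAsyncQueryRequest("select 1", "mys3", LangType.SQL, /* sessionId */ null));

// Status comes back as a string: "RUNNING" on the batch path, or a statement
// state such as "waiting"/"failed" on the session path.
AsyncQueryExecutionResponse result =
    asyncQueryExecutorService.getAsyncQueryResults(created.getQueryId());
if (!"SUCCESS".equals(result.getStatus())) {
  asyncQueryExecutorService.cancelQuery(created.getQueryId());
}
```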
diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java
index 01bccd9030..2ed316795f 100644
--- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java
+++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java
@@ -11,6 +11,7 @@
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.verifyNoInteractions;
 import static org.mockito.Mockito.when;
+import static org.opensearch.sql.spark.asyncquery.OpensearchAsyncQueryAsyncQueryJobMetadataStorageServiceTest.DS_NAME;
 import static org.opensearch.sql.spark.constants.TestConstants.EMRS_APPLICATION_ID;
 import static org.opensearch.sql.spark.constants.TestConstants.EMR_JOB_ID;
 import static org.opensearch.sql.spark.constants.TestConstants.TEST_CLUSTER_NAME;
@@ -29,6 +30,7 @@
 import org.mockito.junit.jupiter.MockitoExtension;
 import org.opensearch.sql.spark.asyncquery.exceptions.AsyncQueryNotFoundException;
 import org.opensearch.sql.spark.asyncquery.model.AsyncQueryExecutionResponse;
+import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId;
 import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata;
 import org.opensearch.sql.spark.config.SparkExecutionEngineConfig;
 import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplier;
@@ -47,6 +49,7 @@ public class AsyncQueryExecutorServiceImplTest {
   private AsyncQueryExecutorService jobExecutorService;
 
   @Mock private SparkExecutionEngineConfigSupplier sparkExecutionEngineConfigSupplier;
+  private final AsyncQueryId QUERY_ID = AsyncQueryId.newAsyncQueryId(DS_NAME);
 
   @BeforeEach
   void setUp() {
@@ -78,11 +81,12 @@ void testCreateAsyncQuery() {
                 LangType.SQL,
                 "arn:aws:iam::270824043731:role/emr-job-execution-role",
                 TEST_CLUSTER_NAME)))
-        .thenReturn(new DispatchQueryResponse(EMR_JOB_ID, false, null));
+        .thenReturn(new DispatchQueryResponse(QUERY_ID, EMR_JOB_ID, false, null, null));
     CreateAsyncQueryResponse createAsyncQueryResponse =
         jobExecutorService.createAsyncQuery(createAsyncQueryRequest);
     verify(asyncQueryJobMetadataStorageService, times(1))
-        .storeJobMetadata(new AsyncQueryJobMetadata("00fd775baqpu4g0p", EMR_JOB_ID, null));
+        .storeJobMetadata(
+            new AsyncQueryJobMetadata(QUERY_ID, "00fd775baqpu4g0p", EMR_JOB_ID, null));
     verify(sparkExecutionEngineConfigSupplier, times(1)).getSparkExecutionEngineConfig();
     verify(sparkQueryDispatcher, times(1))
         .dispatch(
@@ -93,7 +97,7 @@ void testCreateAsyncQuery() {
                 LangType.SQL,
                 "arn:aws:iam::270824043731:role/emr-job-execution-role",
                 TEST_CLUSTER_NAME));
-    Assertions.assertEquals(EMR_JOB_ID, createAsyncQueryResponse.getQueryId());
+    Assertions.assertEquals(QUERY_ID.getId(), createAsyncQueryResponse.getQueryId());
   }
 
   @Test
@@ -107,7 +111,7 @@ void testCreateAsyncQueryWithExtraSparkSubmitParameter() {
             "--conf spark.dynamicAllocation.enabled=false",
             TEST_CLUSTER_NAME));
     when(sparkQueryDispatcher.dispatch(any()))
-        .thenReturn(new DispatchQueryResponse(EMR_JOB_ID, false, null));
+        .thenReturn(new DispatchQueryResponse(QUERY_ID, EMR_JOB_ID, false, null, null));
 
     jobExecutorService.createAsyncQuery(
         new CreateAsyncQueryRequest(
@@ -139,11 +143,13 @@ void testGetAsyncQueryResultsWithJobNotFoundException() {
 
   @Test
   void testGetAsyncQueryResultsWithInProgressJob() {
     when(asyncQueryJobMetadataStorageService.getJobMetadata(EMR_JOB_ID))
-        .thenReturn(Optional.of(new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, EMR_JOB_ID, null)));
+        .thenReturn(
+            Optional.of(
+                new AsyncQueryJobMetadata(QUERY_ID, EMRS_APPLICATION_ID, EMR_JOB_ID, null)));
     JSONObject jobResult = new JSONObject();
     jobResult.put("status", JobRunState.PENDING.toString());
     when(sparkQueryDispatcher.getQueryResponse(
-            new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, EMR_JOB_ID, null)))
+            new AsyncQueryJobMetadata(QUERY_ID, EMRS_APPLICATION_ID, EMR_JOB_ID, null)))
         .thenReturn(jobResult);
     AsyncQueryExecutionResponse asyncQueryExecutionResponse =
         jobExecutorService.getAsyncQueryResults(EMR_JOB_ID);
@@ -157,11 +163,13 @@ void testGetAsyncQueryResultsWithInProgressJob() {
   @Test
   void testGetAsyncQueryResultsWithSuccessJob() throws IOException {
     when(asyncQueryJobMetadataStorageService.getJobMetadata(EMR_JOB_ID))
-        .thenReturn(Optional.of(new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, EMR_JOB_ID, null)));
+        .thenReturn(
+            Optional.of(
+                new AsyncQueryJobMetadata(QUERY_ID, EMRS_APPLICATION_ID, EMR_JOB_ID, null)));
     JSONObject jobResult = new JSONObject(getJson("select_query_response.json"));
     jobResult.put("status", JobRunState.SUCCESS.toString());
     when(sparkQueryDispatcher.getQueryResponse(
-            new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, EMR_JOB_ID, null)))
+            new AsyncQueryJobMetadata(QUERY_ID, EMRS_APPLICATION_ID, EMR_JOB_ID, null)))
         .thenReturn(jobResult);
 
     AsyncQueryExecutionResponse asyncQueryExecutionResponse =
@@ -208,9 +216,11 @@ void testCancelJobWithJobNotFound() {
   @Test
   void testCancelJob() {
     when(asyncQueryJobMetadataStorageService.getJobMetadata(EMR_JOB_ID))
-        .thenReturn(Optional.of(new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, EMR_JOB_ID, null)));
+        .thenReturn(
+            Optional.of(
+                new AsyncQueryJobMetadata(QUERY_ID, EMRS_APPLICATION_ID, EMR_JOB_ID, null)));
     when(sparkQueryDispatcher.cancelJob(
-            new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, EMR_JOB_ID, null)))
+            new AsyncQueryJobMetadata(QUERY_ID, EMRS_APPLICATION_ID, EMR_JOB_ID, null)))
         .thenReturn(EMR_JOB_ID);
     String jobId = jobExecutorService.cancelQuery(EMR_JOB_ID);
     Assertions.assertEquals(EMR_JOB_ID, jobId);
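The thread running through these assertion changes: the client-visible identifier is now the plugin-generated `AsyncQueryId` rather than the raw EMR-S job id, and it travels inside `DispatchQueryResponse` and `AsyncQueryJobMetadata`. Illustrative only; constructor argument order follows the tests above, and the final getter chain is an assumption based on the test's use of `QUERY_ID.getId()`:

```java
// The id encodes the datasource name, per AsyncQueryId.newAsyncQueryId(DS_NAME) above.
AsyncQueryId queryId = AsyncQueryId.newAsyncQueryId("mys3");
DispatchQueryResponse response =
    new DispatchQueryResponse(queryId, "jobId", false, /* resultIndex */ null, /* sessionId */ null);

// What reaches the REST layer is queryId.getId(), not the EMR-S job id.
String clientFacingId = queryId.getId();
```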
diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryAsyncQueryJobMetadataStorageServiceTest.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryAsyncQueryJobMetadataStorageServiceTest.java
index 7288fd3fc2..de0caf5589 100644
--- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryAsyncQueryJobMetadataStorageServiceTest.java
+++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryAsyncQueryJobMetadataStorageServiceTest.java
@@ -5,242 +5,70 @@
 package org.opensearch.sql.spark.asyncquery;
 
-import static org.opensearch.sql.spark.asyncquery.OpensearchAsyncQueryJobMetadataStorageService.JOB_METADATA_INDEX;
 import static org.opensearch.sql.spark.constants.TestConstants.EMRS_APPLICATION_ID;
 import static org.opensearch.sql.spark.constants.TestConstants.EMR_JOB_ID;
 
 import java.util.Optional;
-import org.apache.lucene.search.TotalHits;
-import org.junit.jupiter.api.Assertions;
-import org.junit.jupiter.api.Test;
-import org.junit.jupiter.api.extension.ExtendWith;
-import org.mockito.Answers;
-import org.mockito.ArgumentMatchers;
-import org.mockito.InjectMocks;
-import org.mockito.Mock;
-import org.mockito.Mockito;
-import org.mockito.junit.jupiter.MockitoExtension;
-import org.opensearch.action.DocWriteResponse;
-import org.opensearch.action.admin.indices.create.CreateIndexResponse;
-import org.opensearch.action.index.IndexResponse;
-import org.opensearch.action.search.SearchResponse;
-import org.opensearch.client.Client;
-import org.opensearch.cluster.service.ClusterService;
-import org.opensearch.common.action.ActionFuture;
-import org.opensearch.core.rest.RestStatus;
-import org.opensearch.search.SearchHit;
-import org.opensearch.search.SearchHits;
+import org.junit.Before;
+import org.junit.Test;
+import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId;
 import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata;
+import org.opensearch.sql.spark.execution.statestore.StateStore;
+import org.opensearch.test.OpenSearchIntegTestCase;
 
-@ExtendWith(MockitoExtension.class)
-public class OpensearchAsyncQueryAsyncQueryJobMetadataStorageServiceTest {
+public class OpensearchAsyncQueryAsyncQueryJobMetadataStorageServiceTest
+    extends OpenSearchIntegTestCase {
 
-  @Mock(answer = Answers.RETURNS_DEEP_STUBS)
-  private Client client;
-
-  @Mock(answer = Answers.RETURNS_DEEP_STUBS)
-  private ClusterService clusterService;
-
-  @Mock(answer = Answers.RETURNS_DEEP_STUBS)
-  private SearchResponse searchResponse;
-
-  @Mock private ActionFuture<SearchResponse> searchResponseActionFuture;
-  @Mock private ActionFuture<CreateIndexResponse> createIndexResponseActionFuture;
-  @Mock private ActionFuture<IndexResponse> indexResponseActionFuture;
-  @Mock private IndexResponse indexResponse;
-  @Mock private SearchHit searchHit;
-
-  @InjectMocks
+  public static final String DS_NAME = "mys3";
+  private static final String MOCK_SESSION_ID = "sessionId";
+  private static final String MOCK_RESULT_INDEX = "resultIndex";
   private OpensearchAsyncQueryJobMetadataStorageService opensearchJobMetadataStorageService;
 
-  @Test
-  public void testStoreJobMetadata() {
-
-    Mockito.when(clusterService.state().routingTable().hasIndex(JOB_METADATA_INDEX))
-        .thenReturn(Boolean.FALSE);
-    Mockito.when(client.admin().indices().create(ArgumentMatchers.any()))
-        .thenReturn(createIndexResponseActionFuture);
-    Mockito.when(createIndexResponseActionFuture.actionGet())
-        .thenReturn(new CreateIndexResponse(true, true, JOB_METADATA_INDEX));
-    Mockito.when(client.index(ArgumentMatchers.any())).thenReturn(indexResponseActionFuture);
-    Mockito.when(indexResponseActionFuture.actionGet()).thenReturn(indexResponse);
-    Mockito.when(indexResponse.getResult()).thenReturn(DocWriteResponse.Result.CREATED);
-    AsyncQueryJobMetadata asyncQueryJobMetadata =
-        new AsyncQueryJobMetadata(EMR_JOB_ID, EMRS_APPLICATION_ID, null);
-
-    this.opensearchJobMetadataStorageService.storeJobMetadata(asyncQueryJobMetadata);
-
-    Mockito.verify(client.admin().indices(), Mockito.times(1)).create(ArgumentMatchers.any());
-    Mockito.verify(client, Mockito.times(1)).index(ArgumentMatchers.any());
-    Mockito.verify(client.threadPool().getThreadContext(), Mockito.times(2)).stashContext();
-  }
-
-  @Test
-  public void testStoreJobMetadataWithOutCreatingIndex() {
-    Mockito.when(clusterService.state().routingTable().hasIndex(JOB_METADATA_INDEX))
-        .thenReturn(Boolean.TRUE);
-    Mockito.when(client.index(ArgumentMatchers.any())).thenReturn(indexResponseActionFuture);
-    Mockito.when(indexResponseActionFuture.actionGet()).thenReturn(indexResponse);
-    Mockito.when(indexResponse.getResult()).thenReturn(DocWriteResponse.Result.CREATED);
-    AsyncQueryJobMetadata asyncQueryJobMetadata =
-        new AsyncQueryJobMetadata(EMR_JOB_ID, EMRS_APPLICATION_ID, null);
-
-    this.opensearchJobMetadataStorageService.storeJobMetadata(asyncQueryJobMetadata);
-
-    Mockito.verify(client.admin().indices(), Mockito.times(0)).create(ArgumentMatchers.any());
-    Mockito.verify(client, Mockito.times(1)).index(ArgumentMatchers.any());
-    Mockito.verify(client.threadPool().getThreadContext(), Mockito.times(1)).stashContext();
-  }
-
-  @Test
-  public void testStoreJobMetadataWithException() {
-
-    Mockito.when(clusterService.state().routingTable().hasIndex(JOB_METADATA_INDEX))
-        .thenReturn(Boolean.FALSE);
-    Mockito.when(client.admin().indices().create(ArgumentMatchers.any()))
-        .thenReturn(createIndexResponseActionFuture);
-    Mockito.when(createIndexResponseActionFuture.actionGet())
-        .thenReturn(new CreateIndexResponse(true, true, JOB_METADATA_INDEX));
-    Mockito.when(client.index(ArgumentMatchers.any()))
-        .thenThrow(new RuntimeException("error while indexing"));
-
-    AsyncQueryJobMetadata asyncQueryJobMetadata =
-        new AsyncQueryJobMetadata(EMR_JOB_ID, EMRS_APPLICATION_ID, null);
-    RuntimeException runtimeException =
-        Assertions.assertThrows(
-            RuntimeException.class,
-            () -> this.opensearchJobMetadataStorageService.storeJobMetadata(asyncQueryJobMetadata));
-    Assertions.assertEquals(
-        "java.lang.RuntimeException: error while indexing", runtimeException.getMessage());
-
-    Mockito.verify(client.admin().indices(), Mockito.times(1)).create(ArgumentMatchers.any());
-    Mockito.verify(client, Mockito.times(1)).index(ArgumentMatchers.any());
-    Mockito.verify(client.threadPool().getThreadContext(), Mockito.times(2)).stashContext();
-  }
-
-  @Test
-  public void testStoreJobMetadataWithIndexCreationFailed() {
-
-    Mockito.when(clusterService.state().routingTable().hasIndex(JOB_METADATA_INDEX))
-        .thenReturn(Boolean.FALSE);
-    Mockito.when(client.admin().indices().create(ArgumentMatchers.any()))
-        .thenReturn(createIndexResponseActionFuture);
-    Mockito.when(createIndexResponseActionFuture.actionGet())
-        .thenReturn(new CreateIndexResponse(false, false, JOB_METADATA_INDEX));
-
-    AsyncQueryJobMetadata asyncQueryJobMetadata =
-        new AsyncQueryJobMetadata(EMR_JOB_ID, EMRS_APPLICATION_ID, null);
-    RuntimeException runtimeException =
-        Assertions.assertThrows(
-            RuntimeException.class,
-            () -> this.opensearchJobMetadataStorageService.storeJobMetadata(asyncQueryJobMetadata));
-    Assertions.assertEquals(
-        "Internal server error while creating.ql-job-metadata index:: "
-            + "Index creation is not acknowledged.",
-        runtimeException.getMessage());
-
-    Mockito.verify(client.admin().indices(), Mockito.times(1)).create(ArgumentMatchers.any());
-    Mockito.verify(client.threadPool().getThreadContext(), Mockito.times(1)).stashContext();
-  }
-
-  @Test
-  public void testStoreJobMetadataFailedWithNotFoundResponse() {
-
-    Mockito.when(clusterService.state().routingTable().hasIndex(JOB_METADATA_INDEX))
-        .thenReturn(Boolean.FALSE);
-    Mockito.when(client.admin().indices().create(ArgumentMatchers.any()))
-        .thenReturn(createIndexResponseActionFuture);
-    Mockito.when(createIndexResponseActionFuture.actionGet())
-        .thenReturn(new CreateIndexResponse(true, true, JOB_METADATA_INDEX));
-    Mockito.when(client.index(ArgumentMatchers.any())).thenReturn(indexResponseActionFuture);
-    Mockito.when(indexResponseActionFuture.actionGet()).thenReturn(indexResponse);
-    Mockito.when(indexResponse.getResult()).thenReturn(DocWriteResponse.Result.NOT_FOUND);
-
-    AsyncQueryJobMetadata asyncQueryJobMetadata =
-        new AsyncQueryJobMetadata(EMR_JOB_ID, EMRS_APPLICATION_ID, null);
-    RuntimeException runtimeException =
-        Assertions.assertThrows(
-            RuntimeException.class,
-            () -> this.opensearchJobMetadataStorageService.storeJobMetadata(asyncQueryJobMetadata));
-    Assertions.assertEquals(
-        "Saving job metadata information failed with result : not_found",
-        runtimeException.getMessage());
-
-    Mockito.verify(client.admin().indices(), Mockito.times(1)).create(ArgumentMatchers.any());
-    Mockito.verify(client, Mockito.times(1)).index(ArgumentMatchers.any());
-    Mockito.verify(client.threadPool().getThreadContext(), Mockito.times(2)).stashContext();
-  }
-
-  @Test
-  public void testGetJobMetadata() {
-    Mockito.when(clusterService.state().routingTable().hasIndex(JOB_METADATA_INDEX))
-        .thenReturn(true);
-    Mockito.when(client.search(ArgumentMatchers.any())).thenReturn(searchResponseActionFuture);
-    Mockito.when(searchResponseActionFuture.actionGet()).thenReturn(searchResponse);
-    Mockito.when(searchResponse.status()).thenReturn(RestStatus.OK);
-    Mockito.when(searchResponse.getHits())
-        .thenReturn(
-            new SearchHits(
-                new SearchHit[] {searchHit}, new TotalHits(21, TotalHits.Relation.EQUAL_TO), 1.0F));
-    AsyncQueryJobMetadata asyncQueryJobMetadata =
-        new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, EMR_JOB_ID, null);
-    Mockito.when(searchHit.getSourceAsString()).thenReturn(asyncQueryJobMetadata.toString());
-
-    Optional<AsyncQueryJobMetadata> jobMetadataOptional =
-        opensearchJobMetadataStorageService.getJobMetadata(EMR_JOB_ID);
-    Assertions.assertTrue(jobMetadataOptional.isPresent());
-    Assertions.assertEquals(EMR_JOB_ID, jobMetadataOptional.get().getJobId());
-    Assertions.assertEquals(EMRS_APPLICATION_ID, jobMetadataOptional.get().getApplicationId());
+  @Before
+  public void setup() {
+    opensearchJobMetadataStorageService =
+        new OpensearchAsyncQueryJobMetadataStorageService(
+            new StateStore(client(), clusterService()));
   }
 
   @Test
-  public void testGetJobMetadataWith404SearchResponse() {
-    Mockito.when(clusterService.state().routingTable().hasIndex(JOB_METADATA_INDEX))
-        .thenReturn(true);
-    Mockito.when(client.search(ArgumentMatchers.any())).thenReturn(searchResponseActionFuture);
-    Mockito.when(searchResponseActionFuture.actionGet()).thenReturn(searchResponse);
-    Mockito.when(searchResponse.status()).thenReturn(RestStatus.NOT_FOUND);
-
-    RuntimeException runtimeException =
-        Assertions.assertThrows(
-            RuntimeException.class,
-            () -> opensearchJobMetadataStorageService.getJobMetadata(EMR_JOB_ID));
-    Assertions.assertEquals(
-        "Fetching job metadata information failed with status : NOT_FOUND",
-        runtimeException.getMessage());
-  }
-
-  @Test
-  public void testGetJobMetadataWithParsingFailed() {
-    Mockito.when(clusterService.state().routingTable().hasIndex(JOB_METADATA_INDEX))
-        .thenReturn(true);
-    Mockito.when(client.search(ArgumentMatchers.any())).thenReturn(searchResponseActionFuture);
-    Mockito.when(searchResponseActionFuture.actionGet()).thenReturn(searchResponse);
-    Mockito.when(searchResponse.status()).thenReturn(RestStatus.OK);
-    Mockito.when(searchResponse.getHits())
-        .thenReturn(
-            new SearchHits(
-                new SearchHit[] {searchHit}, new TotalHits(21, TotalHits.Relation.EQUAL_TO), 1.0F));
-    Mockito.when(searchHit.getSourceAsString()).thenReturn("..tesJOBs");
-
-    Assertions.assertThrows(
-        RuntimeException.class,
-        () -> opensearchJobMetadataStorageService.getJobMetadata(EMR_JOB_ID));
+  public void testStoreJobMetadata() {
+    AsyncQueryJobMetadata expected =
+        new AsyncQueryJobMetadata(
+            AsyncQueryId.newAsyncQueryId(DS_NAME),
+            EMR_JOB_ID,
+            EMRS_APPLICATION_ID,
+            MOCK_RESULT_INDEX);
+
+    opensearchJobMetadataStorageService.storeJobMetadata(expected);
+    Optional<AsyncQueryJobMetadata> actual =
+        opensearchJobMetadataStorageService.getJobMetadata(expected.getQueryId().getId());
+
+    assertTrue(actual.isPresent());
+    assertEquals(expected, actual.get());
+    assertFalse(actual.get().isDropIndexQuery());
+    assertNull(actual.get().getSessionId());
   }
 
   @Test
-  public void testGetJobMetadataWithNoIndex() {
-    Mockito.when(clusterService.state().routingTable().hasIndex(JOB_METADATA_INDEX))
-        .thenReturn(Boolean.FALSE);
-    Mockito.when(client.admin().indices().create(ArgumentMatchers.any()))
-        .thenReturn(createIndexResponseActionFuture);
-    Mockito.when(createIndexResponseActionFuture.actionGet())
-        .thenReturn(new CreateIndexResponse(true, true, JOB_METADATA_INDEX));
-    Mockito.when(client.index(ArgumentMatchers.any())).thenReturn(indexResponseActionFuture);
-
-    Optional<AsyncQueryJobMetadata> jobMetadata =
-        opensearchJobMetadataStorageService.getJobMetadata(EMR_JOB_ID);
-
-    Assertions.assertFalse(jobMetadata.isPresent());
+  public void testStoreJobMetadataWithResultExtraData() {
+    AsyncQueryJobMetadata expected =
+        new AsyncQueryJobMetadata(
+            AsyncQueryId.newAsyncQueryId(DS_NAME),
+            EMR_JOB_ID,
+            EMRS_APPLICATION_ID,
+            true,
+            MOCK_RESULT_INDEX,
+            MOCK_SESSION_ID);
+
+    opensearchJobMetadataStorageService.storeJobMetadata(expected);
+    Optional<AsyncQueryJobMetadata> actual =
+        opensearchJobMetadataStorageService.getJobMetadata(expected.getQueryId().getId());
+
+    assertTrue(actual.isPresent());
+    assertEquals(expected, actual.get());
+    assertTrue(actual.get().isDropIndexQuery());
+    assertEquals("resultIndex", actual.get().getResultIndex());
+    assertEquals(MOCK_SESSION_ID, actual.get().getSessionId());
   }
 }
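Worth noting about the rewritten test above: `StateStore` exposes function factories (`getStatement`, `getSession`, `updateSessionState`, `DATASOURCE_TO_REQUEST_INDEX`) that bind a store and a datasource name and return a plain `Function`. A hedged illustration of reusing one reader, inferred from the `.apply(...)` call sites in these tests:

```java
import java.util.Optional;
import java.util.function.Function;

// getStatement binds (stateStore, datasource) once; the returned function is reusable.
Function<String, Optional<StatementModel>> readStatement =
    StateStore.getStatement(stateStore, "mys3");

Optional<StatementModel> statement = readStatement.apply(queryId);
statement.ifPresent(s -> System.out.println(s.getStatementState()));
```

Centralizing reads and state transitions behind `StateStore` is what lets this test drop the deep-stub Mockito scaffolding of the old version in favor of a real single-node cluster (`OpenSearchIntegTestCase`).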
diff --git a/spark/src/test/java/org/opensearch/sql/spark/constants/TestConstants.java b/spark/src/test/java/org/opensearch/sql/spark/constants/TestConstants.java
index abae0377a2..3a0d8fc56d 100644
--- a/spark/src/test/java/org/opensearch/sql/spark/constants/TestConstants.java
+++ b/spark/src/test/java/org/opensearch/sql/spark/constants/TestConstants.java
@@ -16,4 +16,6 @@ public class TestConstants {
   public static final String EMRS_JOB_NAME = "job_name";
   public static final String SPARK_SUBMIT_PARAMETERS = "--conf org.flint.sql.SQLJob";
   public static final String TEST_CLUSTER_NAME = "TEST_CLUSTER";
+  public static final String MOCK_SESSION_ID = "s-0123456";
+  public static final String MOCK_STATEMENT_ID = "st-0123456";
 }
diff --git a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java
index ab9761da36..743274d46c 100644
--- a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java
+++ b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java
@@ -5,19 +5,27 @@
 package org.opensearch.sql.spark.dispatcher;
 
+import static org.mockito.Answers.RETURNS_DEEP_STUBS;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.anyString;
 import static org.mockito.ArgumentMatchers.argThat;
+import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.Mockito.doNothing;
+import static org.mockito.Mockito.doReturn;
+import static org.mockito.Mockito.doThrow;
 import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.never;
 import static org.mockito.Mockito.reset;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.verifyNoInteractions;
 import static org.mockito.Mockito.when;
+import static org.opensearch.sql.spark.asyncquery.OpensearchAsyncQueryAsyncQueryJobMetadataStorageServiceTest.DS_NAME;
 import static org.opensearch.sql.spark.constants.TestConstants.EMRS_APPLICATION_ID;
 import static org.opensearch.sql.spark.constants.TestConstants.EMRS_EXECUTION_ROLE;
 import static org.opensearch.sql.spark.constants.TestConstants.EMR_JOB_ID;
+import static org.opensearch.sql.spark.constants.TestConstants.MOCK_SESSION_ID;
+import static org.opensearch.sql.spark.constants.TestConstants.MOCK_STATEMENT_ID;
 import static org.opensearch.sql.spark.constants.TestConstants.TEST_CLUSTER_NAME;
 import static org.opensearch.sql.spark.data.constants.SparkConstants.DATA_FIELD;
 import static org.opensearch.sql.spark.data.constants.SparkConstants.ERROR_FIELD;
@@ -34,13 +42,15 @@
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.Map;
+import java.util.Optional;
 import java.util.concurrent.ExecutionException;
 import org.json.JSONObject;
 import org.junit.jupiter.api.Assertions;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.extension.ExtendWith;
-import org.mockito.Answers;
+import org.mockito.ArgumentCaptor;
+import org.mockito.Captor;
 import org.mockito.Mock;
 import org.mockito.junit.jupiter.MockitoExtension;
 import org.opensearch.action.support.master.AcknowledgedResponse;
@@ -49,13 +59,17 @@
 import org.opensearch.sql.datasource.model.DataSourceMetadata;
 import org.opensearch.sql.datasource.model.DataSourceType;
 import org.opensearch.sql.datasources.auth.DataSourceUserAuthorizationHelperImpl;
+import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId;
 import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata;
 import org.opensearch.sql.spark.client.EMRServerlessClient;
 import org.opensearch.sql.spark.client.StartJobRequest;
-import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest;
-import org.opensearch.sql.spark.dispatcher.model.DispatchQueryResponse;
-import org.opensearch.sql.spark.dispatcher.model.FullyQualifiedTableName;
-import org.opensearch.sql.spark.dispatcher.model.IndexDetails;
+import org.opensearch.sql.spark.dispatcher.model.*;
+import org.opensearch.sql.spark.execution.session.Session;
+import org.opensearch.sql.spark.execution.session.SessionId;
+import org.opensearch.sql.spark.execution.session.SessionManager;
+import org.opensearch.sql.spark.execution.statement.Statement;
+import org.opensearch.sql.spark.execution.statement.StatementId;
+import org.opensearch.sql.spark.execution.statement.StatementState;
 import org.opensearch.sql.spark.flint.FlintIndexMetadata;
 import org.opensearch.sql.spark.flint.FlintIndexMetadataReader;
 import org.opensearch.sql.spark.flint.FlintIndexType;
@@ -71,13 +85,25 @@ public class SparkQueryDispatcherTest {
   @Mock private DataSourceUserAuthorizationHelperImpl dataSourceUserAuthorizationHelper;
   @Mock private FlintIndexMetadataReader flintIndexMetadataReader;
 
-  @Mock(answer = Answers.RETURNS_DEEP_STUBS)
+  @Mock(answer = RETURNS_DEEP_STUBS)
   private Client openSearchClient;
 
   @Mock private FlintIndexMetadata flintIndexMetadata;
 
+  @Mock private SessionManager sessionManager;
+
+  @Mock(answer = RETURNS_DEEP_STUBS)
+  private Session session;
+
+  @Mock(answer = RETURNS_DEEP_STUBS)
+  private Statement statement;
+
   private SparkQueryDispatcher sparkQueryDispatcher;
 
+  private final AsyncQueryId QUERY_ID = AsyncQueryId.newAsyncQueryId(DS_NAME);
+
+  @Captor ArgumentCaptor<StartJobRequest> startJobRequestArgumentCaptor;
+
   @BeforeEach
   void setUp() {
     sparkQueryDispatcher =
@@ -87,7 +113,8 @@ void setUp() {
             dataSourceUserAuthorizationHelper,
             jobExecutionResponseReader,
             flintIndexMetadataReader,
-            openSearchClient);
+            openSearchClient,
+            sessionManager);
   }
 
   @Test
@@ -95,20 +122,23 @@ void testDispatchSelectQuery() {
     HashMap<String, String> tags = new HashMap<>();
     tags.put("datasource", "my_glue");
     tags.put("cluster", TEST_CLUSTER_NAME);
+    tags.put("type", JobType.BATCH.getText());
     String query = "select * from my_glue.default.http_logs";
+    String sparkSubmitParameters =
+        constructExpectedSparkSubmitParameterString(
+            "sigv4",
+            new HashMap<>() {
+              {
+                put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1");
+              }
+            });
     when(emrServerlessClient.startJobRun(
             new StartJobRequest(
                 query,
                 "TEST_CLUSTER:non-index-query",
                 EMRS_APPLICATION_ID,
                 EMRS_EXECUTION_ROLE,
-                constructExpectedSparkSubmitParameterString(
-                    "sigv4",
-                    new HashMap<>() {
-                      {
-                        put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1");
-                      }
-                    }),
+                sparkSubmitParameters,
                 tags,
                 false,
                 any())))
@@ -125,23 +155,18 @@ void testDispatchSelectQuery() {
                 LangType.SQL,
                 EMRS_EXECUTION_ROLE,
                 TEST_CLUSTER_NAME));
-    verify(emrServerlessClient, times(1))
-        .startJobRun(
-            new StartJobRequest(
-                query,
-                "TEST_CLUSTER:non-index-query",
-                EMRS_APPLICATION_ID,
-                EMRS_EXECUTION_ROLE,
-                constructExpectedSparkSubmitParameterString(
-                    "sigv4",
-                    new HashMap<>() {
-                      {
-                        put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1");
-                      }
-                    }),
-                tags,
-                false,
-                any()));
+    verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture());
+    StartJobRequest expected =
+        new StartJobRequest(
+            query,
+            "TEST_CLUSTER:non-index-query",
+            EMRS_APPLICATION_ID,
+            EMRS_EXECUTION_ROLE,
+            sparkSubmitParameters,
+            tags,
+            false,
+            null);
+    Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue());
     Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId());
     Assertions.assertFalse(dispatchQueryResponse.isDropIndexQuery());
     verifyNoInteractions(flintIndexMetadataReader);
@@ -152,21 +177,24 @@ void testDispatchSelectQueryWithBasicAuthIndexStoreDatasource() {
     HashMap<String, String> tags = new HashMap<>();
     tags.put("datasource", "my_glue");
     tags.put("cluster", TEST_CLUSTER_NAME);
+    tags.put("type", JobType.BATCH.getText());
     String query = "select * from my_glue.default.http_logs";
+    String sparkSubmitParameters =
+        constructExpectedSparkSubmitParameterString(
+            "basic",
+            new HashMap<>() {
+              {
+                put(FLINT_INDEX_STORE_AUTH_USERNAME, "username");
+                put(FLINT_INDEX_STORE_AUTH_PASSWORD, "password");
+              }
+            });
     when(emrServerlessClient.startJobRun(
             new StartJobRequest(
                 query,
                 "TEST_CLUSTER:non-index-query",
                 EMRS_APPLICATION_ID,
                 EMRS_EXECUTION_ROLE,
-                constructExpectedSparkSubmitParameterString(
-                    "basicauth",
-                    new HashMap<>() {
-                      {
-                        put(FLINT_INDEX_STORE_AUTH_USERNAME, "username");
-                        put(FLINT_INDEX_STORE_AUTH_PASSWORD, "password");
-                      }
-                    }),
+                sparkSubmitParameters,
                 tags,
                 false,
                 any())))
@@ -183,24 +211,18 @@ void testDispatchSelectQueryWithBasicAuthIndexStoreDatasource() {
                 LangType.SQL,
                 EMRS_EXECUTION_ROLE,
                 TEST_CLUSTER_NAME));
-    verify(emrServerlessClient, times(1))
-        .startJobRun(
-            new StartJobRequest(
-                query,
-                "TEST_CLUSTER:non-index-query",
-                EMRS_APPLICATION_ID,
-                EMRS_EXECUTION_ROLE,
-                constructExpectedSparkSubmitParameterString(
-                    "basicauth",
-                    new HashMap<>() {
-                      {
-                        put(FLINT_INDEX_STORE_AUTH_USERNAME, "username");
-                        put(FLINT_INDEX_STORE_AUTH_PASSWORD, "password");
-                      }
-                    }),
-                tags,
-                false,
-                any()));
+    verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture());
+    StartJobRequest expected =
+        new StartJobRequest(
+            query,
+            "TEST_CLUSTER:non-index-query",
+            EMRS_APPLICATION_ID,
+            EMRS_EXECUTION_ROLE,
+            sparkSubmitParameters,
+            tags,
+            false,
+            null);
+    Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue());
     Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId());
     Assertions.assertFalse(dispatchQueryResponse.isDropIndexQuery());
     verifyNoInteractions(flintIndexMetadataReader);
@@ -211,19 +233,22 @@ void testDispatchSelectQueryWithNoAuthIndexStoreDatasource() {
     HashMap<String, String> tags = new HashMap<>();
     tags.put("datasource", "my_glue");
     tags.put("cluster", TEST_CLUSTER_NAME);
+    tags.put("type", JobType.BATCH.getText());
     String query = "select * from my_glue.default.http_logs";
+    String sparkSubmitParameters =
+        constructExpectedSparkSubmitParameterString(
+            "noauth",
+            new HashMap<>() {
+              {
+              }
+            });
     when(emrServerlessClient.startJobRun(
            new StartJobRequest(
                 query,
                 "TEST_CLUSTER:non-index-query",
                 EMRS_APPLICATION_ID,
                 EMRS_EXECUTION_ROLE,
-                constructExpectedSparkSubmitParameterString(
-                    "noauth",
-                    new HashMap<>() {
-                      {
-                      }
-                    }),
+                sparkSubmitParameters,
                 tags,
                 false,
                 any())))
@@ -240,52 +265,130 @@ void testDispatchSelectQueryWithNoAuthIndexStoreDatasource() {
                 LangType.SQL,
                 EMRS_EXECUTION_ROLE,
                 TEST_CLUSTER_NAME));
-    verify(emrServerlessClient, times(1))
-        .startJobRun(
-            new StartJobRequest(
-                query,
-                "TEST_CLUSTER:non-index-query",
-                EMRS_APPLICATION_ID,
-                EMRS_EXECUTION_ROLE,
-                constructExpectedSparkSubmitParameterString(
-                    "noauth",
-                    new HashMap<>() {
-                      {
-                      }
-                    }),
-                tags,
-                false,
-                any()));
+    verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture());
+    StartJobRequest expected =
+        new StartJobRequest(
+            query,
+            "TEST_CLUSTER:non-index-query",
+            EMRS_APPLICATION_ID,
+            EMRS_EXECUTION_ROLE,
+            sparkSubmitParameters,
+            tags,
+            false,
+            null);
+    Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue());
     Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId());
     Assertions.assertFalse(dispatchQueryResponse.isDropIndexQuery());
     verifyNoInteractions(flintIndexMetadataReader);
   }
 
+  @Test
+  void testDispatchSelectQueryCreateNewSession() {
+    String query = "select * from my_glue.default.http_logs";
+    DispatchQueryRequest queryRequest = dispatchQueryRequestWithSessionId(query, null);
+
+    doReturn(true).when(sessionManager).isEnabled();
+    doReturn(session).when(sessionManager).createSession(any());
+    doReturn(new SessionId(MOCK_SESSION_ID)).when(session).getSessionId();
+    doReturn(new StatementId(MOCK_STATEMENT_ID)).when(session).submit(any());
+    when(session.getSessionModel().getJobId()).thenReturn(EMR_JOB_ID);
+    DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
+    when(dataSourceService.getRawDataSourceMetadata("my_glue")).thenReturn(dataSourceMetadata);
+    doNothing().when(dataSourceUserAuthorizationHelper).authorizeDataSource(dataSourceMetadata);
+    DispatchQueryResponse dispatchQueryResponse = sparkQueryDispatcher.dispatch(queryRequest);
+
+    verifyNoInteractions(emrServerlessClient);
+    verify(sessionManager, never()).getSession(any());
+    Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId());
+    Assertions.assertEquals(MOCK_SESSION_ID, dispatchQueryResponse.getSessionId());
+  }
+
+  @Test
+  void testDispatchSelectQueryReuseSession() {
+    String query = "select * from my_glue.default.http_logs";
+    DispatchQueryRequest queryRequest = dispatchQueryRequestWithSessionId(query, MOCK_SESSION_ID);
+
+    doReturn(true).when(sessionManager).isEnabled();
+    doReturn(Optional.of(session))
+        .when(sessionManager)
+        .getSession(eq(new SessionId(MOCK_SESSION_ID)));
+    doReturn(new SessionId(MOCK_SESSION_ID)).when(session).getSessionId();
+    doReturn(new StatementId(MOCK_STATEMENT_ID)).when(session).submit(any());
+    when(session.getSessionModel().getJobId()).thenReturn(EMR_JOB_ID);
+    when(session.isReady()).thenReturn(true);
+    DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
+    when(dataSourceService.getRawDataSourceMetadata("my_glue")).thenReturn(dataSourceMetadata);
+    doNothing().when(dataSourceUserAuthorizationHelper).authorizeDataSource(dataSourceMetadata);
+    DispatchQueryResponse dispatchQueryResponse = sparkQueryDispatcher.dispatch(queryRequest);
+
+    verifyNoInteractions(emrServerlessClient);
+    verify(sessionManager, never()).createSession(any());
+    Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId());
+    Assertions.assertEquals(MOCK_SESSION_ID, dispatchQueryResponse.getSessionId());
+  }
+
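Between the session-reuse test above and the invalid/failed-session tests that follow, the dispatcher's session branching is pinned down to roughly the shape below. This is a paraphrase for reviewers, not the dispatcher's literal code; `request`, `createSessionRequest`, `queryRequest`, and `queryId` stand in for the real local variables:

```java
// Paraphrased control flow that these four session tests collectively assert.
if (sessionManager.isEnabled()) {
  Session session;
  if (request.getSessionId() != null) {
    SessionId sessionId = new SessionId(request.getSessionId());
    session =
        sessionManager
            .getSession(sessionId)
            .orElseThrow(() -> new IllegalArgumentException("no session found. " + sessionId));
  } else {
    // May throw; no EMR-S call happens in that case (see testDispatchSelectQueryFailedCreateSession).
    session = sessionManager.createSession(createSessionRequest);
  }
  StatementId statementId = session.submit(queryRequest);
  return new DispatchQueryResponse(
      queryId,
      session.getSessionModel().getJobId(),
      false,
      null,
      session.getSessionId().getSessionId());
}
// Sessions disabled: fall through to the batch StartJobRequest path tested earlier.
```

Note the reuse path also consults `session.isReady()`; the spec test `recreateSessionIfNotReady` covers the service-level fallback when a supplied session is FAIL or DEAD.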
verifyNoInteractions(emrServerlessClient); + verify(sessionManager, never()).createSession(any()); + Assertions.assertEquals( + "no session found. " + new SessionId("invalid"), exception.getMessage()); + } + + @Test + void testDispatchSelectQueryFailedCreateSession() { + String query = "select * from my_glue.default.http_logs"; + DispatchQueryRequest queryRequest = dispatchQueryRequestWithSessionId(query, null); + + doReturn(true).when(sessionManager).isEnabled(); + doThrow(RuntimeException.class).when(sessionManager).createSession(any()); + DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); + when(dataSourceService.getRawDataSourceMetadata("my_glue")).thenReturn(dataSourceMetadata); + doNothing().when(dataSourceUserAuthorizationHelper).authorizeDataSource(dataSourceMetadata); + Assertions.assertThrows( + RuntimeException.class, () -> sparkQueryDispatcher.dispatch(queryRequest)); + + verifyNoInteractions(emrServerlessClient); + } + @Test void testDispatchIndexQuery() { HashMap tags = new HashMap<>(); tags.put("datasource", "my_glue"); - tags.put("table", "http_logs"); - tags.put("index", "elb_and_requestUri"); + tags.put("index", "flint_my_glue_default_http_logs_elb_and_requesturi_index"); tags.put("cluster", TEST_CLUSTER_NAME); - tags.put("schema", "default"); + tags.put("type", JobType.STREAMING.getText()); String query = "CREATE INDEX elb_and_requestUri ON my_glue.default.http_logs(l_orderkey, l_quantity) WITH" + " (auto_refresh = true)"; + String sparkSubmitParameters = + withStructuredStreaming( + constructExpectedSparkSubmitParameterString( + "sigv4", + new HashMap<>() { + { + put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); + } + })); when(emrServerlessClient.startJobRun( new StartJobRequest( query, "TEST_CLUSTER:index-query", EMRS_APPLICATION_ID, EMRS_EXECUTION_ROLE, - withStructuredStreaming( - constructExpectedSparkSubmitParameterString( - "sigv4", - new HashMap<>() { - { - put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); - } - })), + sparkSubmitParameters, tags, true, any()))) @@ -302,24 +405,18 @@ void testDispatchIndexQuery() { LangType.SQL, EMRS_EXECUTION_ROLE, TEST_CLUSTER_NAME)); - verify(emrServerlessClient, times(1)) - .startJobRun( - new StartJobRequest( - query, - "TEST_CLUSTER:index-query", - EMRS_APPLICATION_ID, - EMRS_EXECUTION_ROLE, - withStructuredStreaming( - constructExpectedSparkSubmitParameterString( - "sigv4", - new HashMap<>() { - { - put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); - } - })), - tags, - true, - any())); + verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); + StartJobRequest expected = + new StartJobRequest( + query, + "TEST_CLUSTER:index-query", + EMRS_APPLICATION_ID, + EMRS_EXECUTION_ROLE, + sparkSubmitParameters, + tags, + true, + null); + Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue()); Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId()); Assertions.assertFalse(dispatchQueryResponse.isDropIndexQuery()); verifyNoInteractions(flintIndexMetadataReader); @@ -330,21 +427,23 @@ void testDispatchWithPPLQuery() { HashMap tags = new HashMap<>(); tags.put("datasource", "my_glue"); tags.put("cluster", TEST_CLUSTER_NAME); - + tags.put("type", JobType.BATCH.getText()); String query = "source = my_glue.default.http_logs"; + String sparkSubmitParameters = + constructExpectedSparkSubmitParameterString( + "sigv4", + new HashMap<>() { + { + put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); + } + }); when(emrServerlessClient.startJobRun( new 
StartJobRequest( query, "TEST_CLUSTER:non-index-query", EMRS_APPLICATION_ID, EMRS_EXECUTION_ROLE, - constructExpectedSparkSubmitParameterString( - "sigv4", - new HashMap<>() { - { - put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); - } - }), + sparkSubmitParameters, tags, false, any()))) @@ -361,50 +460,165 @@ void testDispatchWithPPLQuery() { LangType.PPL, EMRS_EXECUTION_ROLE, TEST_CLUSTER_NAME)); - verify(emrServerlessClient, times(1)) - .startJobRun( + verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); + StartJobRequest expected = + new StartJobRequest( + query, + "TEST_CLUSTER:non-index-query", + EMRS_APPLICATION_ID, + EMRS_EXECUTION_ROLE, + sparkSubmitParameters, + tags, + false, + null); + Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue()); + Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId()); + Assertions.assertFalse(dispatchQueryResponse.isDropIndexQuery()); + verifyNoInteractions(flintIndexMetadataReader); + } + + @Test + void testDispatchQueryWithoutATableAndDataSourceName() { + HashMap tags = new HashMap<>(); + tags.put("datasource", "my_glue"); + tags.put("cluster", TEST_CLUSTER_NAME); + tags.put("type", JobType.BATCH.getText()); + String query = "show tables"; + String sparkSubmitParameters = + constructExpectedSparkSubmitParameterString( + "sigv4", + new HashMap<>() { + { + put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); + } + }); + when(emrServerlessClient.startJobRun( new StartJobRequest( query, "TEST_CLUSTER:non-index-query", EMRS_APPLICATION_ID, EMRS_EXECUTION_ROLE, - constructExpectedSparkSubmitParameterString( - "sigv4", - new HashMap<>() { - { - put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); - } - }), + sparkSubmitParameters, tags, false, - any())); + any()))) + .thenReturn(EMR_JOB_ID); + DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); + when(dataSourceService.getRawDataSourceMetadata("my_glue")).thenReturn(dataSourceMetadata); + doNothing().when(dataSourceUserAuthorizationHelper).authorizeDataSource(dataSourceMetadata); + DispatchQueryResponse dispatchQueryResponse = + sparkQueryDispatcher.dispatch( + new DispatchQueryRequest( + EMRS_APPLICATION_ID, + query, + "my_glue", + LangType.SQL, + EMRS_EXECUTION_ROLE, + TEST_CLUSTER_NAME)); + verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); + StartJobRequest expected = + new StartJobRequest( + query, + "TEST_CLUSTER:non-index-query", + EMRS_APPLICATION_ID, + EMRS_EXECUTION_ROLE, + sparkSubmitParameters, + tags, + false, + null); + Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue()); Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId()); Assertions.assertFalse(dispatchQueryResponse.isDropIndexQuery()); verifyNoInteractions(flintIndexMetadataReader); } @Test - void testDispatchQueryWithoutATableAndDataSourceName() { + void testDispatchIndexQueryWithoutADatasourceName() { HashMap tags = new HashMap<>(); tags.put("datasource", "my_glue"); + tags.put("index", "flint_my_glue_default_http_logs_elb_and_requesturi_index"); tags.put("cluster", TEST_CLUSTER_NAME); + tags.put("type", JobType.STREAMING.getText()); + String query = + "CREATE INDEX elb_and_requestUri ON default.http_logs(l_orderkey, l_quantity) WITH" + + " (auto_refresh = true)"; + String sparkSubmitParameters = + withStructuredStreaming( + constructExpectedSparkSubmitParameterString( + "sigv4", + new HashMap<>() { + { + put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); + 
} + })); + when(emrServerlessClient.startJobRun( + new StartJobRequest( + query, + "TEST_CLUSTER:index-query", + EMRS_APPLICATION_ID, + EMRS_EXECUTION_ROLE, + sparkSubmitParameters, + tags, + true, + any()))) + .thenReturn(EMR_JOB_ID); + DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); + when(dataSourceService.getRawDataSourceMetadata("my_glue")).thenReturn(dataSourceMetadata); + doNothing().when(dataSourceUserAuthorizationHelper).authorizeDataSource(dataSourceMetadata); + DispatchQueryResponse dispatchQueryResponse = + sparkQueryDispatcher.dispatch( + new DispatchQueryRequest( + EMRS_APPLICATION_ID, + query, + "my_glue", + LangType.SQL, + EMRS_EXECUTION_ROLE, + TEST_CLUSTER_NAME)); + verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); + StartJobRequest expected = + new StartJobRequest( + query, + "TEST_CLUSTER:index-query", + EMRS_APPLICATION_ID, + EMRS_EXECUTION_ROLE, + sparkSubmitParameters, + tags, + true, + null); + Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue()); + Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId()); + Assertions.assertFalse(dispatchQueryResponse.isDropIndexQuery()); + verifyNoInteractions(flintIndexMetadataReader); + } - String query = "show tables"; + @Test + void testDispatchMaterializedViewQuery() { + HashMap tags = new HashMap<>(); + tags.put("datasource", "my_glue"); + tags.put("index", "flint_mv_1"); + tags.put("cluster", TEST_CLUSTER_NAME); + tags.put("type", JobType.STREAMING.getText()); + String query = + "CREATE MATERIALIZED VIEW mv_1 AS query=select * from my_glue.default.logs WITH" + + " (auto_refresh = true)"; + String sparkSubmitParameters = + withStructuredStreaming( + constructExpectedSparkSubmitParameterString( + "sigv4", + new HashMap<>() { + { + put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); + } + })); when(emrServerlessClient.startJobRun( new StartJobRequest( query, - "TEST_CLUSTER:non-index-query", + "TEST_CLUSTER:index-query", EMRS_APPLICATION_ID, EMRS_EXECUTION_ROLE, - constructExpectedSparkSubmitParameterString( - "sigv4", - new HashMap<>() { - { - put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); - } - }), + sparkSubmitParameters, tags, - false, + true, any()))) .thenReturn(EMR_JOB_ID); DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); @@ -419,56 +633,100 @@ void testDispatchQueryWithoutATableAndDataSourceName() { LangType.SQL, EMRS_EXECUTION_ROLE, TEST_CLUSTER_NAME)); - verify(emrServerlessClient, times(1)) - .startJobRun( + verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); + StartJobRequest expected = + new StartJobRequest( + query, + "TEST_CLUSTER:index-query", + EMRS_APPLICATION_ID, + EMRS_EXECUTION_ROLE, + sparkSubmitParameters, + tags, + true, + null); + Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue()); + Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId()); + Assertions.assertFalse(dispatchQueryResponse.isDropIndexQuery()); + verifyNoInteractions(flintIndexMetadataReader); + } + + @Test + void testDispatchShowMVQuery() { + HashMap tags = new HashMap<>(); + tags.put("datasource", "my_glue"); + tags.put("cluster", TEST_CLUSTER_NAME); + String query = "SHOW MATERIALIZED VIEW IN mys3.default"; + String sparkSubmitParameters = + constructExpectedSparkSubmitParameterString( + "sigv4", + new HashMap<>() { + { + put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); + } + }); + when(emrServerlessClient.startJobRun( new 
StartJobRequest( query, - "TEST_CLUSTER:non-index-query", + "TEST_CLUSTER:index-query", EMRS_APPLICATION_ID, EMRS_EXECUTION_ROLE, - constructExpectedSparkSubmitParameterString( - "sigv4", - new HashMap<>() { - { - put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); - } - }), + sparkSubmitParameters, tags, false, - any())); + any()))) + .thenReturn(EMR_JOB_ID); + DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); + when(dataSourceService.getRawDataSourceMetadata("my_glue")).thenReturn(dataSourceMetadata); + doNothing().when(dataSourceUserAuthorizationHelper).authorizeDataSource(dataSourceMetadata); + DispatchQueryResponse dispatchQueryResponse = + sparkQueryDispatcher.dispatch( + new DispatchQueryRequest( + EMRS_APPLICATION_ID, + query, + "my_glue", + LangType.SQL, + EMRS_EXECUTION_ROLE, + TEST_CLUSTER_NAME)); + verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); + StartJobRequest expected = + new StartJobRequest( + query, + "TEST_CLUSTER:index-query", + EMRS_APPLICATION_ID, + EMRS_EXECUTION_ROLE, + sparkSubmitParameters, + tags, + false, + null); + Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue()); Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId()); Assertions.assertFalse(dispatchQueryResponse.isDropIndexQuery()); verifyNoInteractions(flintIndexMetadataReader); } @Test - void testDispatchIndexQueryWithoutADatasourceName() { + void testRefreshIndexQuery() { HashMap tags = new HashMap<>(); tags.put("datasource", "my_glue"); - tags.put("table", "http_logs"); - tags.put("index", "elb_and_requestUri"); tags.put("cluster", TEST_CLUSTER_NAME); - tags.put("schema", "default"); - - String query = - "CREATE INDEX elb_and_requestUri ON default.http_logs(l_orderkey, l_quantity) WITH" - + " (auto_refresh = true)"; + String query = "REFRESH SKIPPING INDEX ON my_glue.default.http_logs"; + String sparkSubmitParameters = + constructExpectedSparkSubmitParameterString( + "sigv4", + new HashMap<>() { + { + put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); + } + }); when(emrServerlessClient.startJobRun( new StartJobRequest( query, "TEST_CLUSTER:index-query", EMRS_APPLICATION_ID, EMRS_EXECUTION_ROLE, - withStructuredStreaming( - constructExpectedSparkSubmitParameterString( - "sigv4", - new HashMap<>() { - { - put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); - } - })), + sparkSubmitParameters, tags, - true, + false, any()))) .thenReturn(EMR_JOB_ID); DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); @@ -483,24 +741,72 @@ void testDispatchIndexQueryWithoutADatasourceName() { LangType.SQL, EMRS_EXECUTION_ROLE, TEST_CLUSTER_NAME)); - verify(emrServerlessClient, times(1)) - .startJobRun( + verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); + StartJobRequest expected = + new StartJobRequest( + query, + "TEST_CLUSTER:index-query", + EMRS_APPLICATION_ID, + EMRS_EXECUTION_ROLE, + sparkSubmitParameters, + tags, + false, + null); + Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue()); + Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId()); + Assertions.assertFalse(dispatchQueryResponse.isDropIndexQuery()); + verifyNoInteractions(flintIndexMetadataReader); + } + + @Test + void testDispatchDescribeIndexQuery() { + HashMap tags = new HashMap<>(); + tags.put("datasource", "my_glue"); + tags.put("cluster", TEST_CLUSTER_NAME); + String query = "DESCRIBE SKIPPING INDEX ON mys3.default.http_logs"; + String 
sparkSubmitParameters = + constructExpectedSparkSubmitParameterString( + "sigv4", + new HashMap<>() { + { + put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); + } + }); + when(emrServerlessClient.startJobRun( new StartJobRequest( query, "TEST_CLUSTER:index-query", EMRS_APPLICATION_ID, EMRS_EXECUTION_ROLE, - withStructuredStreaming( - constructExpectedSparkSubmitParameterString( - "sigv4", - new HashMap<>() { - { - put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); - } - })), + sparkSubmitParameters, tags, - true, - any())); + false, + any()))) + .thenReturn(EMR_JOB_ID); + DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); + when(dataSourceService.getRawDataSourceMetadata("my_glue")).thenReturn(dataSourceMetadata); + doNothing().when(dataSourceUserAuthorizationHelper).authorizeDataSource(dataSourceMetadata); + DispatchQueryResponse dispatchQueryResponse = + sparkQueryDispatcher.dispatch( + new DispatchQueryRequest( + EMRS_APPLICATION_ID, + query, + "my_glue", + LangType.SQL, + EMRS_EXECUTION_ROLE, + TEST_CLUSTER_NAME)); + verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); + StartJobRequest expected = + new StartJobRequest( + query, + "TEST_CLUSTER:index-query", + EMRS_APPLICATION_ID, + EMRS_EXECUTION_ROLE, + sparkSubmitParameters, + tags, + false, + null); + Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue()); Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId()); Assertions.assertFalse(dispatchQueryResponse.isDropIndexQuery()); verifyNoInteractions(flintIndexMetadataReader); @@ -557,10 +863,68 @@ void testCancelJob() { new CancelJobRunResult() .withJobRunId(EMR_JOB_ID) .withApplicationId(EMRS_APPLICATION_ID)); - String jobId = + String queryId = sparkQueryDispatcher.cancelJob(asyncQueryJobMetadata()); + Assertions.assertEquals(QUERY_ID.getId(), queryId); + } + + @Test + void testCancelQueryWithSession() { + doReturn(Optional.of(session)).when(sessionManager).getSession(new SessionId(MOCK_SESSION_ID)); + doReturn(Optional.of(statement)).when(session).get(any()); + doNothing().when(statement).cancel(); + + String queryId = sparkQueryDispatcher.cancelJob( - new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, EMR_JOB_ID, null)); - Assertions.assertEquals(EMR_JOB_ID, jobId); + asyncQueryJobMetadataWithSessionId(MOCK_STATEMENT_ID, MOCK_SESSION_ID)); + + verifyNoInteractions(emrServerlessClient); + verify(statement, times(1)).cancel(); + Assertions.assertEquals(MOCK_STATEMENT_ID, queryId); + } + + @Test + void testCancelQueryWithInvalidSession() { + doReturn(Optional.empty()).when(sessionManager).getSession(new SessionId("invalid")); + + IllegalArgumentException exception = + Assertions.assertThrows( + IllegalArgumentException.class, + () -> + sparkQueryDispatcher.cancelJob( + asyncQueryJobMetadataWithSessionId(MOCK_STATEMENT_ID, "invalid"))); + + verifyNoInteractions(emrServerlessClient); + verifyNoInteractions(session); + Assertions.assertEquals( + "no session found. 
" + new SessionId("invalid"), exception.getMessage()); + } + + @Test + void testCancelQueryWithInvalidStatementId() { + doReturn(Optional.of(session)).when(sessionManager).getSession(new SessionId(MOCK_SESSION_ID)); + + IllegalArgumentException exception = + Assertions.assertThrows( + IllegalArgumentException.class, + () -> + sparkQueryDispatcher.cancelJob( + asyncQueryJobMetadataWithSessionId("invalid", MOCK_SESSION_ID))); + + verifyNoInteractions(emrServerlessClient); + verifyNoInteractions(statement); + Assertions.assertEquals( + "no statement found. " + new StatementId("invalid"), exception.getMessage()); + } + + @Test + void testCancelQueryWithNoSessionId() { + when(emrServerlessClient.cancelJobRun(EMRS_APPLICATION_ID, EMR_JOB_ID)) + .thenReturn( + new CancelJobRunResult() + .withJobRunId(EMR_JOB_ID) + .withApplicationId(EMRS_APPLICATION_ID)); + String queryId = sparkQueryDispatcher.cancelJob(asyncQueryJobMetadata()); + Assertions.assertEquals(QUERY_ID.getId(), queryId); } @Test @@ -571,10 +935,63 @@ void testGetQueryResponse() { // simulate result index is not created yet when(jobExecutionResponseReader.getResultFromOpensearchIndex(EMR_JOB_ID, null)) .thenReturn(new JSONObject()); + JSONObject result = sparkQueryDispatcher.getQueryResponse(asyncQueryJobMetadata()); + Assertions.assertEquals("PENDING", result.get("status")); + } + + @Test + void testGetQueryResponseWithSession() { + doReturn(Optional.of(session)).when(sessionManager).getSession(new SessionId(MOCK_SESSION_ID)); + doReturn(Optional.of(statement)).when(session).get(any()); + when(statement.getStatementModel().getError()).thenReturn("mock error"); + doReturn(StatementState.WAITING).when(statement).getStatementState(); + + doReturn(new JSONObject()) + .when(jobExecutionResponseReader) + .getResultWithQueryId(eq(MOCK_STATEMENT_ID), any()); JSONObject result = sparkQueryDispatcher.getQueryResponse( - new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, EMR_JOB_ID, null)); - Assertions.assertEquals("PENDING", result.get("status")); + asyncQueryJobMetadataWithSessionId(MOCK_STATEMENT_ID, MOCK_SESSION_ID)); + + verifyNoInteractions(emrServerlessClient); + Assertions.assertEquals("waiting", result.get("status")); + } + + @Test + void testGetQueryResponseWithInvalidSession() { + doReturn(Optional.empty()).when(sessionManager).getSession(eq(new SessionId(MOCK_SESSION_ID))); + doReturn(new JSONObject()) + .when(jobExecutionResponseReader) + .getResultWithQueryId(eq(MOCK_STATEMENT_ID), any()); + IllegalArgumentException exception = + Assertions.assertThrows( + IllegalArgumentException.class, + () -> + sparkQueryDispatcher.getQueryResponse( + asyncQueryJobMetadataWithSessionId(MOCK_STATEMENT_ID, MOCK_SESSION_ID))); + + verifyNoInteractions(emrServerlessClient); + Assertions.assertEquals( + "no session found. " + new SessionId(MOCK_SESSION_ID), exception.getMessage()); + } + + @Test + void testGetQueryResponseWithStatementNotExist() { + doReturn(Optional.of(session)).when(sessionManager).getSession(new SessionId(MOCK_SESSION_ID)); + doReturn(Optional.empty()).when(session).get(any()); + doReturn(new JSONObject()) + .when(jobExecutionResponseReader) + .getResultWithQueryId(eq(MOCK_STATEMENT_ID), any()); + + IllegalArgumentException exception = + Assertions.assertThrows( + IllegalArgumentException.class, + () -> + sparkQueryDispatcher.getQueryResponse( + asyncQueryJobMetadataWithSessionId(MOCK_STATEMENT_ID, MOCK_SESSION_ID))); + verifyNoInteractions(emrServerlessClient); + Assertions.assertEquals( + "no statement found. 
" + new StatementId(MOCK_STATEMENT_ID), exception.getMessage()); } @Test @@ -586,7 +1003,8 @@ void testGetQueryResponseWithSuccess() { dataSourceUserAuthorizationHelper, jobExecutionResponseReader, flintIndexMetadataReader, - openSearchClient); + openSearchClient, + sessionManager); JSONObject queryResult = new JSONObject(); Map resultMap = new HashMap<>(); resultMap.put(STATUS_FIELD, "SUCCESS"); @@ -594,9 +1012,7 @@ void testGetQueryResponseWithSuccess() { queryResult.put(DATA_FIELD, resultMap); when(jobExecutionResponseReader.getResultFromOpensearchIndex(EMR_JOB_ID, null)) .thenReturn(queryResult); - JSONObject result = - sparkQueryDispatcher.getQueryResponse( - new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, EMR_JOB_ID, null)); + JSONObject result = sparkQueryDispatcher.getQueryResponse(asyncQueryJobMetadata()); verify(jobExecutionResponseReader, times(1)).getResultFromOpensearchIndex(EMR_JOB_ID, null); Assertions.assertEquals( new HashSet<>(Arrays.asList(DATA_FIELD, STATUS_FIELD, ERROR_FIELD)), result.keySet()); @@ -623,14 +1039,21 @@ void testGetQueryResponseOfDropIndex() { dataSourceUserAuthorizationHelper, jobExecutionResponseReader, flintIndexMetadataReader, - openSearchClient); + openSearchClient, + sessionManager); String jobId = new SparkQueryDispatcher.DropIndexResult(JobRunState.SUCCESS.toString()).toJobId(); JSONObject result = sparkQueryDispatcher.getQueryResponse( - new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, jobId, true, null)); + new AsyncQueryJobMetadata( + AsyncQueryId.newAsyncQueryId(DS_NAME), + EMRS_APPLICATION_ID, + jobId, + true, + null, + null)); verify(jobExecutionResponseReader, times(0)) .getResultFromOpensearchIndex(anyString(), anyString()); Assertions.assertEquals("SUCCESS", result.get(STATUS_FIELD)); @@ -639,13 +1062,15 @@ void testGetQueryResponseOfDropIndex() { @Test void testDropIndexQuery() throws ExecutionException, InterruptedException { String query = "DROP INDEX size_year ON my_glue.default.http_logs"; - when(flintIndexMetadataReader.getFlintIndexMetadata( - new IndexDetails( - "size_year", - new FullyQualifiedTableName("my_glue.default.http_logs"), - false, - true, - FlintIndexType.COVERING))) + IndexQueryDetails indexQueryDetails = + IndexQueryDetails.builder() + .indexName("size_year") + .fullyQualifiedTableName(new FullyQualifiedTableName("my_glue.default.http_logs")) + .autoRefresh(false) + .indexQueryActionType(IndexQueryActionType.DROP) + .indexType(FlintIndexType.COVERING) + .build(); + when(flintIndexMetadataReader.getFlintIndexMetadata(indexQueryDetails)) .thenReturn(flintIndexMetadata); when(flintIndexMetadata.getJobId()).thenReturn(EMR_JOB_ID); // auto_refresh == true @@ -674,15 +1099,7 @@ void testDropIndexQuery() throws ExecutionException, InterruptedException { TEST_CLUSTER_NAME)); verify(emrServerlessClient, times(1)).cancelJobRun(EMRS_APPLICATION_ID, EMR_JOB_ID); verify(dataSourceUserAuthorizationHelper, times(1)).authorizeDataSource(dataSourceMetadata); - verify(flintIndexMetadataReader, times(1)) - .getFlintIndexMetadata( - new IndexDetails( - "size_year", - new FullyQualifiedTableName("my_glue.default.http_logs"), - false, - true, - FlintIndexType.COVERING)); - + verify(flintIndexMetadataReader, times(1)).getFlintIndexMetadata(indexQueryDetails); SparkQueryDispatcher.DropIndexResult dropIndexResult = SparkQueryDispatcher.DropIndexResult.fromJobId(dispatchQueryResponse.getJobId()); Assertions.assertEquals(JobRunState.SUCCESS.toString(), dropIndexResult.getStatus()); @@ -692,13 +1109,14 @@ void testDropIndexQuery() throws 
ExecutionException, InterruptedException { @Test void testDropSkippingIndexQuery() throws ExecutionException, InterruptedException { String query = "DROP SKIPPING INDEX ON my_glue.default.http_logs"; - when(flintIndexMetadataReader.getFlintIndexMetadata( - new IndexDetails( - null, - new FullyQualifiedTableName("my_glue.default.http_logs"), - false, - true, - FlintIndexType.SKIPPING))) + IndexQueryDetails indexQueryDetails = + IndexQueryDetails.builder() + .fullyQualifiedTableName(new FullyQualifiedTableName("my_glue.default.http_logs")) + .autoRefresh(false) + .indexQueryActionType(IndexQueryActionType.DROP) + .indexType(FlintIndexType.SKIPPING) + .build(); + when(flintIndexMetadataReader.getFlintIndexMetadata(indexQueryDetails)) .thenReturn(flintIndexMetadata); when(flintIndexMetadata.getJobId()).thenReturn(EMR_JOB_ID); when(flintIndexMetadata.isAutoRefresh()).thenReturn(true); @@ -725,14 +1143,7 @@ void testDropSkippingIndexQuery() throws ExecutionException, InterruptedExceptio TEST_CLUSTER_NAME)); verify(emrServerlessClient, times(1)).cancelJobRun(EMRS_APPLICATION_ID, EMR_JOB_ID); verify(dataSourceUserAuthorizationHelper, times(1)).authorizeDataSource(dataSourceMetadata); - verify(flintIndexMetadataReader, times(1)) - .getFlintIndexMetadata( - new IndexDetails( - null, - new FullyQualifiedTableName("my_glue.default.http_logs"), - false, - true, - FlintIndexType.SKIPPING)); + verify(flintIndexMetadataReader, times(1)).getFlintIndexMetadata(indexQueryDetails); SparkQueryDispatcher.DropIndexResult dropIndexResult = SparkQueryDispatcher.DropIndexResult.fromJobId(dispatchQueryResponse.getJobId()); Assertions.assertEquals(JobRunState.SUCCESS.toString(), dropIndexResult.getStatus()); @@ -743,13 +1154,14 @@ void testDropSkippingIndexQuery() throws ExecutionException, InterruptedExceptio void testDropSkippingIndexQueryAutoRefreshFalse() throws ExecutionException, InterruptedException { String query = "DROP SKIPPING INDEX ON my_glue.default.http_logs"; - when(flintIndexMetadataReader.getFlintIndexMetadata( - new IndexDetails( - null, - new FullyQualifiedTableName("my_glue.default.http_logs"), - false, - true, - FlintIndexType.SKIPPING))) + IndexQueryDetails indexQueryDetails = + IndexQueryDetails.builder() + .fullyQualifiedTableName(new FullyQualifiedTableName("my_glue.default.http_logs")) + .autoRefresh(false) + .indexQueryActionType(IndexQueryActionType.DROP) + .indexType(FlintIndexType.SKIPPING) + .build(); + when(flintIndexMetadataReader.getFlintIndexMetadata(indexQueryDetails)) .thenReturn(flintIndexMetadata); when(flintIndexMetadata.isAutoRefresh()).thenReturn(false); @@ -770,14 +1182,7 @@ void testDropSkippingIndexQueryAutoRefreshFalse() TEST_CLUSTER_NAME)); verify(emrServerlessClient, times(0)).cancelJobRun(EMRS_APPLICATION_ID, EMR_JOB_ID); verify(dataSourceUserAuthorizationHelper, times(1)).authorizeDataSource(dataSourceMetadata); - verify(flintIndexMetadataReader, times(1)) - .getFlintIndexMetadata( - new IndexDetails( - null, - new FullyQualifiedTableName("my_glue.default.http_logs"), - false, - true, - FlintIndexType.SKIPPING)); + verify(flintIndexMetadataReader, times(1)).getFlintIndexMetadata(indexQueryDetails); SparkQueryDispatcher.DropIndexResult dropIndexResult = SparkQueryDispatcher.DropIndexResult.fromJobId(dispatchQueryResponse.getJobId()); Assertions.assertEquals(JobRunState.SUCCESS.toString(), dropIndexResult.getStatus()); @@ -788,13 +1193,14 @@ void testDropSkippingIndexQueryAutoRefreshFalse() void testDropSkippingIndexQueryDeleteIndexException() throws ExecutionException, 
InterruptedException { String query = "DROP SKIPPING INDEX ON my_glue.default.http_logs"; - when(flintIndexMetadataReader.getFlintIndexMetadata( - new IndexDetails( - null, - new FullyQualifiedTableName("my_glue.default.http_logs"), - false, - true, - FlintIndexType.SKIPPING))) + IndexQueryDetails indexQueryDetails = + IndexQueryDetails.builder() + .fullyQualifiedTableName(new FullyQualifiedTableName("my_glue.default.http_logs")) + .autoRefresh(false) + .indexQueryActionType(IndexQueryActionType.DROP) + .indexType(FlintIndexType.SKIPPING) + .build(); + when(flintIndexMetadataReader.getFlintIndexMetadata(indexQueryDetails)) .thenReturn(flintIndexMetadata); when(flintIndexMetadata.isAutoRefresh()).thenReturn(false); @@ -816,14 +1222,7 @@ void testDropSkippingIndexQueryDeleteIndexException() TEST_CLUSTER_NAME)); verify(emrServerlessClient, times(0)).cancelJobRun(EMRS_APPLICATION_ID, EMR_JOB_ID); verify(dataSourceUserAuthorizationHelper, times(1)).authorizeDataSource(dataSourceMetadata); - verify(flintIndexMetadataReader, times(1)) - .getFlintIndexMetadata( - new IndexDetails( - null, - new FullyQualifiedTableName("my_glue.default.http_logs"), - false, - true, - FlintIndexType.SKIPPING)); + verify(flintIndexMetadataReader, times(1)).getFlintIndexMetadata(indexQueryDetails); SparkQueryDispatcher.DropIndexResult dropIndexResult = SparkQueryDispatcher.DropIndexResult.fromJobId(dispatchQueryResponse.getJobId()); Assertions.assertEquals(JobRunState.FAILED.toString(), dropIndexResult.getStatus()); @@ -833,6 +1232,52 @@ void testDropSkippingIndexQueryDeleteIndexException() Assertions.assertTrue(dispatchQueryResponse.isDropIndexQuery()); } + @Test + void testDropMVQuery() throws ExecutionException, InterruptedException { + String query = "DROP MATERIALIZED VIEW mv_1"; + IndexQueryDetails indexQueryDetails = + IndexQueryDetails.builder() + .mvName("mv_1") + .indexQueryActionType(IndexQueryActionType.DROP) + .fullyQualifiedTableName(null) + .indexType(FlintIndexType.MATERIALIZED_VIEW) + .build(); + when(flintIndexMetadataReader.getFlintIndexMetadata(indexQueryDetails)) + .thenReturn(flintIndexMetadata); + when(flintIndexMetadata.getJobId()).thenReturn(EMR_JOB_ID); + // auto_refresh == true + when(flintIndexMetadata.isAutoRefresh()).thenReturn(true); + + when(emrServerlessClient.cancelJobRun(EMRS_APPLICATION_ID, EMR_JOB_ID)) + .thenReturn( + new CancelJobRunResult() + .withJobRunId(EMR_JOB_ID) + .withApplicationId(EMRS_APPLICATION_ID)); + DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); + when(dataSourceService.getRawDataSourceMetadata("my_glue")).thenReturn(dataSourceMetadata); + doNothing().when(dataSourceUserAuthorizationHelper).authorizeDataSource(dataSourceMetadata); + + AcknowledgedResponse acknowledgedResponse = mock(AcknowledgedResponse.class); + when(openSearchClient.admin().indices().delete(any()).get()).thenReturn(acknowledgedResponse); + when(acknowledgedResponse.isAcknowledged()).thenReturn(true); + DispatchQueryResponse dispatchQueryResponse = + sparkQueryDispatcher.dispatch( + new DispatchQueryRequest( + EMRS_APPLICATION_ID, + query, + "my_glue", + LangType.SQL, + EMRS_EXECUTION_ROLE, + TEST_CLUSTER_NAME)); + verify(emrServerlessClient, times(1)).cancelJobRun(EMRS_APPLICATION_ID, EMR_JOB_ID); + verify(dataSourceUserAuthorizationHelper, times(1)).authorizeDataSource(dataSourceMetadata); + verify(flintIndexMetadataReader, times(1)).getFlintIndexMetadata(indexQueryDetails); + SparkQueryDispatcher.DropIndexResult dropIndexResult = + 
SparkQueryDispatcher.DropIndexResult.fromJobId(dispatchQueryResponse.getJobId()); + Assertions.assertEquals(JobRunState.SUCCESS.toString(), dropIndexResult.getStatus()); + Assertions.assertTrue(dispatchQueryResponse.isDropIndexQuery()); + } + @Test void testDispatchQueryWithExtraSparkSubmitParameters() { DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); @@ -905,8 +1350,8 @@ private String constructExpectedSparkSubmitParameterString( + " --conf" + " spark.hive.metastore.glue.role.arn=arn:aws:iam::924196221507:role/FlintOpensearchServiceRole" + " --conf spark.sql.catalog.my_glue=org.opensearch.sql.FlintDelegatingSessionCatalog " - + authParamConfigBuilder - + " --conf spark.flint.datasource.name=my_glue "; + + " --conf spark.flint.datasource.name=my_glue " + + authParamConfigBuilder; } private String withStructuredStreaming(String parameters) { @@ -997,6 +1442,29 @@ private DispatchQueryRequest constructDispatchQueryRequest( langType, EMRS_EXECUTION_ROLE, TEST_CLUSTER_NAME, - extraParameters); + extraParameters, + null); + } + + private DispatchQueryRequest dispatchQueryRequestWithSessionId(String query, String sessionId) { + return new DispatchQueryRequest( + EMRS_APPLICATION_ID, + query, + "my_glue", + LangType.SQL, + EMRS_EXECUTION_ROLE, + TEST_CLUSTER_NAME, + null, + sessionId); + } + + private AsyncQueryJobMetadata asyncQueryJobMetadata() { + return new AsyncQueryJobMetadata(QUERY_ID, EMRS_APPLICATION_ID, EMR_JOB_ID, null); + } + + private AsyncQueryJobMetadata asyncQueryJobMetadataWithSessionId( + String statementId, String sessionId) { + return new AsyncQueryJobMetadata( + new AsyncQueryId(statementId), EMRS_APPLICATION_ID, EMR_JOB_ID, false, null, sessionId); } } diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java new file mode 100644 index 0000000000..14ccaf7708 --- /dev/null +++ b/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java @@ -0,0 +1,231 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.execution.session; + +import static org.opensearch.sql.spark.execution.session.InteractiveSessionTest.TestSession.testSession; +import static org.opensearch.sql.spark.execution.session.SessionManagerTest.sessionSetting; +import static org.opensearch.sql.spark.execution.session.SessionState.NOT_STARTED; +import static org.opensearch.sql.spark.execution.statestore.StateStore.DATASOURCE_TO_REQUEST_INDEX; +import static org.opensearch.sql.spark.execution.statestore.StateStore.getSession; + +import com.amazonaws.services.emrserverless.model.CancelJobRunResult; +import com.amazonaws.services.emrserverless.model.GetJobRunResult; +import java.util.HashMap; +import java.util.Optional; +import lombok.RequiredArgsConstructor; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.opensearch.action.admin.indices.delete.DeleteIndexRequest; +import org.opensearch.action.delete.DeleteRequest; +import org.opensearch.sql.spark.asyncquery.model.SparkSubmitParameters; +import org.opensearch.sql.spark.client.EMRServerlessClient; +import org.opensearch.sql.spark.client.StartJobRequest; +import org.opensearch.sql.spark.execution.statestore.StateStore; +import org.opensearch.test.OpenSearchIntegTestCase; + +/** mock-maker-inline does not work with OpenSearchTestCase. 
*/ +public class InteractiveSessionTest extends OpenSearchIntegTestCase { + + private static final String DS_NAME = "mys3"; + private static final String indexName = DATASOURCE_TO_REQUEST_INDEX.apply(DS_NAME); + + private TestEMRServerlessClient emrsClient; + private StartJobRequest startJobRequest; + private StateStore stateStore; + + @Before + public void setup() { + emrsClient = new TestEMRServerlessClient(); + startJobRequest = new StartJobRequest("", "", "appId", "", "", new HashMap<>(), false, ""); + stateStore = new StateStore(client(), clusterService()); + } + + @After + public void clean() { + if (clusterService().state().routingTable().hasIndex(indexName)) { + client().admin().indices().delete(new DeleteIndexRequest(indexName)).actionGet(); + } + } + + @Test + public void openCloseSession() { + InteractiveSession session = + InteractiveSession.builder() + .sessionId(SessionId.newSessionId(DS_NAME)) + .stateStore(stateStore) + .serverlessClient(emrsClient) + .build(); + + // open session + TestSession testSession = testSession(session, stateStore); + testSession + .open(createSessionRequest()) + .assertSessionState(NOT_STARTED) + .assertAppId("appId") + .assertJobId("jobId"); + emrsClient.startJobRunCalled(1); + + // close session + testSession.close(); + emrsClient.cancelJobRunCalled(1); + } + + @Test + public void openSessionFailedConflict() { + SessionId sessionId = SessionId.newSessionId(DS_NAME); + InteractiveSession session = + InteractiveSession.builder() + .sessionId(sessionId) + .stateStore(stateStore) + .serverlessClient(emrsClient) + .build(); + session.open(createSessionRequest()); + + InteractiveSession duplicateSession = + InteractiveSession.builder() + .sessionId(sessionId) + .stateStore(stateStore) + .serverlessClient(emrsClient) + .build(); + IllegalStateException exception = + assertThrows( + IllegalStateException.class, () -> duplicateSession.open(createSessionRequest())); + assertEquals("session already exist. " + sessionId, exception.getMessage()); + } + + @Test + public void closeNotExistSession() { + SessionId sessionId = SessionId.newSessionId(DS_NAME); + InteractiveSession session = + InteractiveSession.builder() + .sessionId(sessionId) + .stateStore(stateStore) + .serverlessClient(emrsClient) + .build(); + session.open(createSessionRequest()); + + client().delete(new DeleteRequest(indexName, sessionId.getSessionId())).actionGet(); + + IllegalStateException exception = assertThrows(IllegalStateException.class, session::close); + assertEquals("session does not exist. 
" + sessionId, exception.getMessage()); + emrsClient.cancelJobRunCalled(0); + } + + @Test + public void sessionManagerCreateSession() { + Session session = + new SessionManager(stateStore, emrsClient, sessionSetting(false)) + .createSession(createSessionRequest()); + + TestSession testSession = testSession(session, stateStore); + testSession.assertSessionState(NOT_STARTED).assertAppId("appId").assertJobId("jobId"); + } + + @Test + public void sessionManagerGetSession() { + SessionManager sessionManager = + new SessionManager(stateStore, emrsClient, sessionSetting(false)); + Session session = sessionManager.createSession(createSessionRequest()); + + Optional managerSession = sessionManager.getSession(session.getSessionId()); + assertTrue(managerSession.isPresent()); + assertEquals(session.getSessionId(), managerSession.get().getSessionId()); + } + + @Test + public void sessionManagerGetSessionNotExist() { + SessionManager sessionManager = + new SessionManager(stateStore, emrsClient, sessionSetting(false)); + + Optional managerSession = + sessionManager.getSession(SessionId.newSessionId("no-exist")); + assertTrue(managerSession.isEmpty()); + } + + @RequiredArgsConstructor + static class TestSession { + private final Session session; + private final StateStore stateStore; + + public static TestSession testSession(Session session, StateStore stateStore) { + return new TestSession(session, stateStore); + } + + public TestSession assertSessionState(SessionState expected) { + assertEquals(expected, session.getSessionModel().getSessionState()); + + Optional sessionStoreState = + getSession(stateStore, DS_NAME).apply(session.getSessionModel().getId()); + assertTrue(sessionStoreState.isPresent()); + assertEquals(expected, sessionStoreState.get().getSessionState()); + + return this; + } + + public TestSession assertAppId(String expected) { + assertEquals(expected, session.getSessionModel().getApplicationId()); + return this; + } + + public TestSession assertJobId(String expected) { + assertEquals(expected, session.getSessionModel().getJobId()); + return this; + } + + public TestSession open(CreateSessionRequest req) { + session.open(req); + return this; + } + + public TestSession close() { + session.close(); + return this; + } + } + + public static CreateSessionRequest createSessionRequest() { + return new CreateSessionRequest( + "jobName", + "appId", + "arn", + SparkSubmitParameters.Builder.builder(), + new HashMap<>(), + "resultIndex", + DS_NAME); + } + + public static class TestEMRServerlessClient implements EMRServerlessClient { + + private int startJobRunCalled = 0; + private int cancelJobRunCalled = 0; + + @Override + public String startJobRun(StartJobRequest startJobRequest) { + startJobRunCalled++; + return "jobId"; + } + + @Override + public GetJobRunResult getJobRunResult(String applicationId, String jobId) { + return null; + } + + @Override + public CancelJobRunResult cancelJobRun(String applicationId, String jobId) { + cancelJobRunCalled++; + return null; + } + + public void startJobRunCalled(int expectedTimes) { + assertEquals(expectedTimes, startJobRunCalled); + } + + public void cancelJobRunCalled(int expectedTimes) { + assertEquals(expectedTimes, cancelJobRunCalled); + } + } +} diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionManagerTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionManagerTest.java new file mode 100644 index 0000000000..3546a874d9 --- /dev/null +++ 
b/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionManagerTest.java @@ -0,0 +1,53 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.execution.session; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.common.setting.Settings; +import org.opensearch.sql.spark.client.EMRServerlessClient; +import org.opensearch.sql.spark.execution.statestore.StateStore; + +@ExtendWith(MockitoExtension.class) +public class SessionManagerTest { + @Mock private StateStore stateStore; + @Mock private EMRServerlessClient emrClient; + + @Test + public void sessionEnable() { + Assertions.assertTrue( + new SessionManager(stateStore, emrClient, sessionSetting(true)).isEnabled()); + Assertions.assertFalse( + new SessionManager(stateStore, emrClient, sessionSetting(false)).isEnabled()); + } + + public static Settings sessionSetting(boolean enabled) { + Map<Settings.Key, Object> settings = new HashMap<>(); + settings.put(Settings.Key.SPARK_EXECUTION_SESSION_ENABLED, enabled); + settings.put(Settings.Key.SPARK_EXECUTION_SESSION_LIMIT, 100); + return settings(settings); + } + + public static Settings settings(Map<Settings.Key, Object> settings) { + return new Settings() { + @Override + public <T> T getSettingValue(Key key) { + return (T) settings.get(key); + } + + @Override + public List<?> getSettings() { + return (List<?>) settings; + } + }; + } +} diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionStateTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionStateTest.java new file mode 100644 index 0000000000..a987c80d59 --- /dev/null +++ b/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionStateTest.java @@ -0,0 +1,20 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.execution.session; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; + +import org.junit.jupiter.api.Test; + +class SessionStateTest { + @Test + public void invalidSessionState() { + IllegalArgumentException exception = + assertThrows(IllegalArgumentException.class, () -> SessionState.fromString("invalid")); + assertEquals("Invalid session state: invalid", exception.getMessage()); + } +} diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionTypeTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionTypeTest.java new file mode 100644 index 0000000000..a2ab43e709 --- /dev/null +++ b/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionTypeTest.java @@ -0,0 +1,20 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.execution.session; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; + +import org.junit.jupiter.api.Test; + +class SessionTypeTest { + @Test + public void invalidSessionType() { + IllegalArgumentException exception = + assertThrows(IllegalArgumentException.class, () -> SessionType.fromString("invalid")); + assertEquals("Invalid session type: invalid", exception.getMessage()); + } +} diff --git 
a/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementStateTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementStateTest.java new file mode 100644 index 0000000000..b7af1123ba --- /dev/null +++ b/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementStateTest.java @@ -0,0 +1,20 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.execution.statement; + +import static org.junit.Assert.assertThrows; + +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +class StatementStateTest { + @Test + public void invalidStatementState() { + IllegalArgumentException exception = + assertThrows(IllegalArgumentException.class, () -> StatementState.fromString("invalid")); + Assertions.assertEquals("Invalid statement state: invalid", exception.getMessage()); + } +} diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java new file mode 100644 index 0000000000..29020f2496 --- /dev/null +++ b/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java @@ -0,0 +1,443 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.execution.statement; + +import static org.opensearch.sql.spark.execution.session.InteractiveSessionTest.createSessionRequest; +import static org.opensearch.sql.spark.execution.session.SessionManagerTest.sessionSetting; +import static org.opensearch.sql.spark.execution.statement.StatementState.CANCELLED; +import static org.opensearch.sql.spark.execution.statement.StatementState.RUNNING; +import static org.opensearch.sql.spark.execution.statement.StatementState.WAITING; +import static org.opensearch.sql.spark.execution.statement.StatementTest.TestStatement.testStatement; +import static org.opensearch.sql.spark.execution.statestore.StateStore.DATASOURCE_TO_REQUEST_INDEX; +import static org.opensearch.sql.spark.execution.statestore.StateStore.getStatement; +import static org.opensearch.sql.spark.execution.statestore.StateStore.updateSessionState; +import static org.opensearch.sql.spark.execution.statestore.StateStore.updateStatementState; + +import java.util.Optional; +import lombok.RequiredArgsConstructor; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.opensearch.action.admin.indices.delete.DeleteIndexRequest; +import org.opensearch.action.delete.DeleteRequest; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId; +import org.opensearch.sql.spark.execution.session.InteractiveSessionTest; +import org.opensearch.sql.spark.execution.session.Session; +import org.opensearch.sql.spark.execution.session.SessionId; +import org.opensearch.sql.spark.execution.session.SessionManager; +import org.opensearch.sql.spark.execution.session.SessionState; +import org.opensearch.sql.spark.execution.statestore.StateStore; +import org.opensearch.sql.spark.rest.model.LangType; +import org.opensearch.test.OpenSearchIntegTestCase; + +public class StatementTest extends OpenSearchIntegTestCase { + + private static final String DS_NAME = "mys3"; + private static final String indexName = DATASOURCE_TO_REQUEST_INDEX.apply(DS_NAME); + + private StateStore stateStore; + private InteractiveSessionTest.TestEMRServerlessClient emrsClient = + new InteractiveSessionTest.TestEMRServerlessClient(); 
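+  // Fixtures shared across these tests: the counting EMR-S stub from InteractiveSessionTest records start/cancel calls, and each test builds a fresh StateStore over the per-datasource request index that clean() deletes afterwards.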
+ + @Before + public void setup() { + stateStore = new StateStore(client(), clusterService()); + } + + @After + public void clean() { + if (clusterService().state().routingTable().hasIndex(indexName)) { + client().admin().indices().delete(new DeleteIndexRequest(indexName)).actionGet(); + } + } + + @Test + public void openThenCancelStatement() { + Statement st = + Statement.builder() + .sessionId(new SessionId("sessionId")) + .applicationId("appId") + .jobId("jobId") + .statementId(new StatementId("statementId")) + .langType(LangType.SQL) + .datasourceName(DS_NAME) + .query("query") + .queryId("statementId") + .stateStore(stateStore) + .build(); + + // submit statement + TestStatement testStatement = testStatement(st, stateStore); + testStatement + .open() + .assertSessionState(WAITING) + .assertStatementId(new StatementId("statementId")); + + // close statement + testStatement.cancel().assertSessionState(CANCELLED); + } + + @Test + public void openFailedBecauseConflict() { + Statement st = + Statement.builder() + .sessionId(new SessionId("sessionId")) + .applicationId("appId") + .jobId("jobId") + .statementId(new StatementId("statementId")) + .langType(LangType.SQL) + .datasourceName(DS_NAME) + .query("query") + .queryId("statementId") + .stateStore(stateStore) + .build(); + st.open(); + + // open statement with same statement id + Statement dupSt = + Statement.builder() + .sessionId(new SessionId("sessionId")) + .applicationId("appId") + .jobId("jobId") + .statementId(new StatementId("statementId")) + .langType(LangType.SQL) + .datasourceName(DS_NAME) + .query("query") + .queryId("statementId") + .stateStore(stateStore) + .build(); + IllegalStateException exception = assertThrows(IllegalStateException.class, dupSt::open); + assertEquals("statement already exist. statementId=statementId", exception.getMessage()); + } + + @Test + public void cancelNotExistStatement() { + StatementId stId = new StatementId("statementId"); + Statement st = + Statement.builder() + .sessionId(new SessionId("sessionId")) + .applicationId("appId") + .jobId("jobId") + .statementId(stId) + .langType(LangType.SQL) + .datasourceName(DS_NAME) + .query("query") + .queryId("statementId") + .stateStore(stateStore) + .build(); + st.open(); + + client().delete(new DeleteRequest(indexName, stId.getId())).actionGet(); + + IllegalStateException exception = assertThrows(IllegalStateException.class, st::cancel); + assertEquals( + String.format("cancel statement failed. no statement found. statement: %s.", stId), + exception.getMessage()); + } + + @Test + public void cancelFailedBecauseOfConflict() { + StatementId stId = new StatementId("statementId"); + Statement st = + Statement.builder() + .sessionId(new SessionId("sessionId")) + .applicationId("appId") + .jobId("jobId") + .statementId(stId) + .langType(LangType.SQL) + .datasourceName(DS_NAME) + .query("query") + .queryId("statementId") + .stateStore(stateStore) + .build(); + st.open(); + + StatementModel running = + updateStatementState(stateStore, DS_NAME).apply(st.getStatementModel(), CANCELLED); + + assertEquals(StatementState.CANCELLED, running.getStatementState()); + + // cancel conflict + IllegalStateException exception = assertThrows(IllegalStateException.class, st::cancel); + assertEquals( + String.format( + "cancel statement failed. 
current statementState: CANCELLED " + "statement: %s.", stId), + exception.getMessage()); + } + + @Test + public void cancelSuccessStatementFailed() { + StatementId stId = new StatementId("statementId"); + Statement st = createStatement(stId); + + // move the statement into the terminal success state + StatementModel model = st.getStatementModel(); + st.setStatementModel( + StatementModel.copyWithState( + st.getStatementModel(), + StatementState.SUCCESS, + model.getSeqNo(), + model.getPrimaryTerm())); + + // cancelling a statement in a terminal state must fail + IllegalStateException exception = assertThrows(IllegalStateException.class, st::cancel); + assertEquals( + String.format("can't cancel statement in success state. statement: %s.", stId), + exception.getMessage()); + } + + @Test + public void cancelFailedStatementFailed() { + StatementId stId = new StatementId("statementId"); + Statement st = createStatement(stId); + + // move the statement into the terminal failed state + StatementModel model = st.getStatementModel(); + st.setStatementModel( + StatementModel.copyWithState( + st.getStatementModel(), + StatementState.FAILED, + model.getSeqNo(), + model.getPrimaryTerm())); + + // cancelling a statement in a terminal state must fail + IllegalStateException exception = assertThrows(IllegalStateException.class, st::cancel); + assertEquals( + String.format("can't cancel statement in failed state. statement: %s.", stId), + exception.getMessage()); + } + + @Test + public void cancelCancelledStatementFailed() { + StatementId stId = new StatementId("statementId"); + Statement st = createStatement(stId); + + // move the statement into the terminal cancelled state + StatementModel model = st.getStatementModel(); + st.setStatementModel( + StatementModel.copyWithState( + st.getStatementModel(), CANCELLED, model.getSeqNo(), model.getPrimaryTerm())); + + // cancelling an already-cancelled statement must fail + IllegalStateException exception = assertThrows(IllegalStateException.class, st::cancel); + assertEquals( + String.format("can't cancel statement in cancelled state. statement: %s.", stId), + exception.getMessage()); + } + + @Test + public void cancelRunningStatementSuccess() { + Statement st = + Statement.builder() + .sessionId(new SessionId("sessionId")) + .applicationId("appId") + .jobId("jobId") + .statementId(new StatementId("statementId")) + .langType(LangType.SQL) + .datasourceName(DS_NAME) + .query("query") + .queryId("statementId") + .stateStore(stateStore) + .build(); + + // submit statement + TestStatement testStatement = testStatement(st, stateStore); + testStatement + .open() + .assertSessionState(WAITING) + .assertStatementId(new StatementId("statementId")); + + testStatement.run(); + + // close statement + testStatement.cancel().assertSessionState(CANCELLED); + } + + @Test + public void submitStatementInRunningSession() { + Session session = + new SessionManager(stateStore, emrsClient, sessionSetting(false)) + .createSession(createSessionRequest()); + + // the application has moved the session into the running state + updateSessionState(stateStore, DS_NAME).apply(session.getSessionModel(), SessionState.RUNNING); + + StatementId statementId = session.submit(queryRequest()); + assertFalse(statementId.getId().isEmpty()); + } + + @Test + public void submitStatementInNotStartedState() { + Session session = + new SessionManager(stateStore, emrsClient, sessionSetting(false)) + .createSession(createSessionRequest()); + + StatementId statementId = session.submit(queryRequest()); + assertFalse(statementId.getId().isEmpty()); + } + + @Test + public void failToSubmitStatementInDeadState() { + Session session = + new SessionManager(stateStore, emrsClient, sessionSetting(false)) + .createSession(createSessionRequest()); + + updateSessionState(stateStore, DS_NAME).apply(session.getSessionModel(), SessionState.DEAD); + + IllegalStateException exception = + assertThrows(IllegalStateException.class, () -> session.submit(queryRequest())); + assertEquals( + "can't submit statement, session should not be in end state, current session state is:" + + " dead", + exception.getMessage()); + } + + @Test + public void failToSubmitStatementInFailState() { + Session session = + new SessionManager(stateStore, emrsClient, sessionSetting(false)) + .createSession(createSessionRequest()); + + updateSessionState(stateStore, DS_NAME).apply(session.getSessionModel(), SessionState.FAIL); + + IllegalStateException exception = + assertThrows(IllegalStateException.class, () -> session.submit(queryRequest())); + assertEquals( + "can't submit statement, session should not be in end state, current session state is:" + + " fail", + exception.getMessage()); + } + + @Test + public void newStatementFieldAssert() { + Session session = + new SessionManager(stateStore, emrsClient, sessionSetting(false)) + .createSession(createSessionRequest()); + StatementId statementId = session.submit(queryRequest()); + Optional<Statement> statement = session.get(statementId); + + assertTrue(statement.isPresent()); + assertEquals(session.getSessionId(), statement.get().getSessionId()); + assertEquals("appId", statement.get().getApplicationId()); + assertEquals("jobId", statement.get().getJobId()); + assertEquals(statementId, statement.get().getStatementId()); + assertEquals(WAITING, statement.get().getStatementState()); + assertEquals(LangType.SQL, statement.get().getLangType()); + assertEquals("select 1", statement.get().getQuery()); + } + + @Test + public void failToSubmitStatementInDeletedSession() { + Session session = + new SessionManager(stateStore, emrsClient, sessionSetting(false)) + .createSession(createSessionRequest()); + + // simulate another writer deleting the session document + client() + .delete(new DeleteRequest(indexName, session.getSessionId().getSessionId())) + .actionGet(); + + IllegalStateException exception = + assertThrows(IllegalStateException.class, () -> session.submit(queryRequest())); + assertEquals("session does not exist. " + session.getSessionId(), exception.getMessage()); + } + + @Test + public void getStatementSuccess() { + Session session = + new SessionManager(stateStore, emrsClient, sessionSetting(false)) + .createSession(createSessionRequest()); + // the application has moved the session into the running state + updateSessionState(stateStore, DS_NAME).apply(session.getSessionModel(), SessionState.RUNNING); + StatementId statementId = session.submit(queryRequest()); + + Optional<Statement> statement = session.get(statementId); + assertTrue(statement.isPresent()); + assertEquals(WAITING, statement.get().getStatementState()); + assertEquals(statementId, statement.get().getStatementId()); + } + + @Test + public void getStatementNotExist() { + Session session = + new SessionManager(stateStore, emrsClient, sessionSetting(false)) + .createSession(createSessionRequest()); + // the application has moved the session into the running state + updateSessionState(stateStore, DS_NAME).apply(session.getSessionModel(), SessionState.RUNNING); + + Optional<Statement> statement = session.get(StatementId.newStatementId("not-exist-id")); + assertFalse(statement.isPresent()); + } + + @RequiredArgsConstructor + static class TestStatement { + private final Statement st; + private final StateStore stateStore; + + public static TestStatement testStatement(Statement st, StateStore stateStore) { + return new TestStatement(st, stateStore); + } + + public TestStatement assertSessionState(StatementState expected) { + assertEquals(expected, st.getStatementModel().getStatementState()); + + Optional<StatementModel> model = + getStatement(stateStore, DS_NAME).apply(st.getStatementId().getId()); + assertTrue(model.isPresent()); + assertEquals(expected, model.get().getStatementState()); + + return this; + } + + public TestStatement assertStatementId(StatementId expected) { + assertEquals(expected, st.getStatementModel().getStatementId()); + + Optional<StatementModel> model = + getStatement(stateStore, DS_NAME).apply(st.getStatementId().getId()); + assertTrue(model.isPresent()); + assertEquals(expected, model.get().getStatementId()); + return this; + } + + public TestStatement open() { + st.open(); + return this; + } + + public TestStatement cancel() { + st.cancel(); + return this; + } + + public TestStatement run() { + StatementModel model = + updateStatementState(stateStore, DS_NAME).apply(st.getStatementModel(), RUNNING); + st.setStatementModel(model); + return this; + } + } + + private QueryRequest queryRequest() { + return new QueryRequest(AsyncQueryId.newAsyncQueryId(DS_NAME), LangType.SQL, "select 1"); + } + + private Statement createStatement(StatementId stId) { + Statement st = + Statement.builder() + .sessionId(new SessionId("sessionId")) + .applicationId("appId") + .jobId("jobId") + .statementId(stId) + .langType(LangType.SQL) + .datasourceName(DS_NAME) + .query("query") + .queryId("statementId") + .stateStore(stateStore) + .build(); + st.open(); + return st; + } +} diff --git a/spark/src/test/java/org/opensearch/sql/spark/flint/FlintIndexMetadataReaderImplTest.java b/spark/src/test/java/org/opensearch/sql/spark/flint/FlintIndexMetadataReaderImplTest.java index b0c8491b0b..4d809c31dc 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/flint/FlintIndexMetadataReaderImplTest.java +++ 
diff --git a/spark/src/test/java/org/opensearch/sql/spark/flint/FlintIndexMetadataReaderImplTest.java b/spark/src/test/java/org/opensearch/sql/spark/flint/FlintIndexMetadataReaderImplTest.java
index b0c8491b0b..4d809c31dc 100644
--- a/spark/src/test/java/org/opensearch/sql/spark/flint/FlintIndexMetadataReaderImplTest.java
+++ b/spark/src/test/java/org/opensearch/sql/spark/flint/FlintIndexMetadataReaderImplTest.java
@@ -25,7 +25,8 @@
 import org.opensearch.core.xcontent.NamedXContentRegistry;
 import org.opensearch.core.xcontent.XContentParser;
 import org.opensearch.sql.spark.dispatcher.model.FullyQualifiedTableName;
-import org.opensearch.sql.spark.dispatcher.model.IndexDetails;
+import org.opensearch.sql.spark.dispatcher.model.IndexQueryActionType;
+import org.opensearch.sql.spark.dispatcher.model.IndexQueryDetails;
 
 @ExtendWith(MockitoExtension.class)
 public class FlintIndexMetadataReaderImplTest {
@@ -44,12 +45,12 @@ void testGetJobIdFromFlintSkippingIndexMetadata() {
     FlintIndexMetadataReader flintIndexMetadataReader = new FlintIndexMetadataReaderImpl(client);
     FlintIndexMetadata indexMetadata =
         flintIndexMetadataReader.getFlintIndexMetadata(
-            new IndexDetails(
-                null,
-                new FullyQualifiedTableName("mys3.default.http_logs"),
-                false,
-                true,
-                FlintIndexType.SKIPPING));
+            IndexQueryDetails.builder()
+                .fullyQualifiedTableName(new FullyQualifiedTableName("mys3.default.http_logs"))
+                .autoRefresh(false)
+                .indexQueryActionType(IndexQueryActionType.DROP)
+                .indexType(FlintIndexType.SKIPPING)
+                .build());
     Assertions.assertEquals("00fdmvv9hp8u0o0q", indexMetadata.getJobId());
   }
 
@@ -64,12 +65,13 @@ void testGetJobIdFromFlintCoveringIndexMetadata() {
     FlintIndexMetadataReader flintIndexMetadataReader = new FlintIndexMetadataReaderImpl(client);
     FlintIndexMetadata indexMetadata =
         flintIndexMetadataReader.getFlintIndexMetadata(
-            new IndexDetails(
-                "cv1",
-                new FullyQualifiedTableName("mys3.default.http_logs"),
-                false,
-                true,
-                FlintIndexType.COVERING));
+            IndexQueryDetails.builder()
+                .indexName("cv1")
+                .fullyQualifiedTableName(new FullyQualifiedTableName("mys3.default.http_logs"))
+                .autoRefresh(false)
+                .indexQueryActionType(IndexQueryActionType.DROP)
+                .indexType(FlintIndexType.COVERING)
+                .build());
     Assertions.assertEquals("00fdmvv9hp8u0o0q", indexMetadata.getJobId());
   }
 
@@ -86,34 +88,17 @@ void testGetJobIDWithNPEException() {
             IllegalArgumentException.class,
             () ->
                 flintIndexMetadataReader.getFlintIndexMetadata(
-                    new IndexDetails(
-                        "cv1",
-                        new FullyQualifiedTableName("mys3.default.http_logs"),
-                        false,
-                        true,
-                        FlintIndexType.COVERING)));
+                    IndexQueryDetails.builder()
+                        .indexName("cv1")
+                        .fullyQualifiedTableName(
+                            new FullyQualifiedTableName("mys3.default.http_logs"))
+                        .autoRefresh(false)
+                        .indexQueryActionType(IndexQueryActionType.DROP)
+                        .indexType(FlintIndexType.COVERING)
+                        .build()));
     Assertions.assertEquals("Provided Index doesn't exist", illegalArgumentException.getMessage());
   }
 
-  @SneakyThrows
-  @Test
-  void testGetJobIdFromUnsupportedIndex() {
-    FlintIndexMetadataReader flintIndexMetadataReader = new FlintIndexMetadataReaderImpl(client);
-    UnsupportedOperationException unsupportedOperationException =
-        Assertions.assertThrows(
-            UnsupportedOperationException.class,
-            () ->
-                flintIndexMetadataReader.getFlintIndexMetadata(
-                    new IndexDetails(
-                        "cv1",
-                        new FullyQualifiedTableName("mys3.default.http_logs"),
-                        false,
-                        true,
-                        FlintIndexType.MATERIALIZED_VIEW)));
-    Assertions.assertEquals(
-        "Unsupported Index Type : MATERIALIZED_VIEW", unsupportedOperationException.getMessage());
-  }
-
   @SneakyThrows
   public void mockNodeClientIndicesMappings(String indexName, String mappings) {
     GetMappingsResponse mockResponse = mock(GetMappingsResponse.class);
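This file's migration is representative of the whole PR: the positional IndexDetails constructor gives way to the IndexQueryDetails builder, and the opaque boolean that appears to have marked a drop becomes an explicit IndexQueryActionType. A before/after sketch using the same values as the tests above:

// Before: argument order carried the meaning (null index name, drop flag = true).
// new IndexDetails(null, new FullyQualifiedTableName("mys3.default.http_logs"),
//     false, true, FlintIndexType.SKIPPING);
// After: every field is named at the call site.
IndexQueryDetails details =
    IndexQueryDetails.builder()
        .fullyQualifiedTableName(new FullyQualifiedTableName("mys3.default.http_logs"))
        .autoRefresh(false)
        .indexQueryActionType(IndexQueryActionType.DROP)
        .indexType(FlintIndexType.SKIPPING)
        .build();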
diff --git a/spark/src/test/java/org/opensearch/sql/spark/flint/IndexDetailsTest.java b/spark/src/test/java/org/opensearch/sql/spark/flint/IndexDetailsTest.java
deleted file mode 100644
index 46fa4f7dbe..0000000000
--- a/spark/src/test/java/org/opensearch/sql/spark/flint/IndexDetailsTest.java
+++ /dev/null
@@ -1,27 +0,0 @@
-/*
- * Copyright OpenSearch Contributors
- * SPDX-License-Identifier: Apache-2.0
- */
-
-package org.opensearch.sql.spark.flint;
-
-import static org.junit.jupiter.api.Assertions.assertEquals;
-
-import org.junit.jupiter.api.Test;
-import org.opensearch.sql.spark.dispatcher.model.FullyQualifiedTableName;
-import org.opensearch.sql.spark.dispatcher.model.IndexDetails;
-
-public class IndexDetailsTest {
-  @Test
-  public void skippingIndexName() {
-    assertEquals(
-        "flint_mys3_default_http_logs_skipping_index",
-        new IndexDetails(
-            "invalid",
-            new FullyQualifiedTableName("mys3.default.http_logs"),
-            false,
-            true,
-            FlintIndexType.SKIPPING)
-            .openSearchIndexName());
-  }
-}
diff --git a/spark/src/test/java/org/opensearch/sql/spark/flint/IndexQueryDetailsTest.java b/spark/src/test/java/org/opensearch/sql/spark/flint/IndexQueryDetailsTest.java
new file mode 100644
index 0000000000..e725ddc21e
--- /dev/null
+++ b/spark/src/test/java/org/opensearch/sql/spark/flint/IndexQueryDetailsTest.java
@@ -0,0 +1,29 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.spark.flint;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+import org.junit.jupiter.api.Test;
+import org.opensearch.sql.spark.dispatcher.model.FullyQualifiedTableName;
+import org.opensearch.sql.spark.dispatcher.model.IndexQueryActionType;
+import org.opensearch.sql.spark.dispatcher.model.IndexQueryDetails;
+
+public class IndexQueryDetailsTest {
+  @Test
+  public void skippingIndexName() {
+    assertEquals(
+        "flint_mys3_default_http_logs_skipping_index",
+        IndexQueryDetails.builder()
+            .indexName("invalid")
+            .fullyQualifiedTableName(new FullyQualifiedTableName("mys3.default.http_logs"))
+            .autoRefresh(false)
+            .indexQueryActionType(IndexQueryActionType.DROP)
+            .indexType(FlintIndexType.SKIPPING)
+            .build()
+            .openSearchIndexName());
+  }
+}
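IndexQueryDetailsTest also pins the Flint naming convention: for a skipping index, openSearchIndexName() is derived entirely from the fully qualified table name, which is why the supplied indexName ("invalid" above) has no effect on the result. The convention, expressed as a hypothetical helper that is not part of the PR:

// flint_{datasource}_{schema}_{table}_skipping_index (illustration only)
static String flintSkippingIndexName(String datasource, String schema, String table) {
  return String.format("flint_%s_%s_%s_skipping_index", datasource, schema, table);
}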
diff --git a/spark/src/test/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequestTest.java b/spark/src/test/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequestTest.java
new file mode 100644
index 0000000000..dd634d6055
--- /dev/null
+++ b/spark/src/test/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequestTest.java
@@ -0,0 +1,52 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.spark.rest.model;
+
+import java.io.IOException;
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Test;
+import org.opensearch.common.xcontent.LoggingDeprecationHandler;
+import org.opensearch.common.xcontent.XContentType;
+import org.opensearch.core.xcontent.NamedXContentRegistry;
+import org.opensearch.core.xcontent.XContentParser;
+
+public class CreateAsyncQueryRequestTest {
+
+  @Test
+  public void fromXContent() throws IOException {
+    String request =
+        "{\n"
+            + "  \"datasource\": \"my_glue\",\n"
+            + "  \"lang\": \"sql\",\n"
+            + "  \"query\": \"select 1\"\n"
+            + "}";
+    CreateAsyncQueryRequest queryRequest =
+        CreateAsyncQueryRequest.fromXContentParser(xContentParser(request));
+    Assertions.assertEquals("my_glue", queryRequest.getDatasource());
+    Assertions.assertEquals(LangType.SQL, queryRequest.getLang());
+    Assertions.assertEquals("select 1", queryRequest.getQuery());
+  }
+
+  @Test
+  public void fromXContentWithSessionId() throws IOException {
+    String request =
+        "{\n"
+            + "  \"datasource\": \"my_glue\",\n"
+            + "  \"lang\": \"sql\",\n"
+            + "  \"query\": \"select 1\",\n"
+            + "  \"sessionId\": \"00fdjevgkf12s00q\"\n"
+            + "}";
+    CreateAsyncQueryRequest queryRequest =
+        CreateAsyncQueryRequest.fromXContentParser(xContentParser(request));
+    Assertions.assertEquals("00fdjevgkf12s00q", queryRequest.getSessionId());
+  }
+
+  private XContentParser xContentParser(String request) throws IOException {
+    return XContentType.JSON
+        .xContent()
+        .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, request);
+  }
+}
diff --git a/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCreateAsyncQueryRequestActionTest.java b/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCreateAsyncQueryRequestActionTest.java
index 8599e4b88e..36060d3850 100644
--- a/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCreateAsyncQueryRequestActionTest.java
+++ b/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCreateAsyncQueryRequestActionTest.java
@@ -11,6 +11,7 @@
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
+import static org.opensearch.sql.spark.constants.TestConstants.MOCK_SESSION_ID;
 
 import java.util.HashSet;
 import org.junit.jupiter.api.Assertions;
@@ -61,7 +62,7 @@ public void testDoExecute() {
     CreateAsyncQueryActionRequest request =
         new CreateAsyncQueryActionRequest(createAsyncQueryRequest);
     when(jobExecutorService.createAsyncQuery(createAsyncQueryRequest))
-        .thenReturn(new CreateAsyncQueryResponse("123"));
+        .thenReturn(new CreateAsyncQueryResponse("123", null));
     action.doExecute(task, request, actionListener);
     Mockito.verify(actionListener).onResponse(createJobActionResponseArgumentCaptor.capture());
     CreateAsyncQueryActionResponse createAsyncQueryActionResponse =
@@ -70,6 +71,24 @@ public void testDoExecute() {
         "{\n" + "  \"queryId\": \"123\"\n" + "}", createAsyncQueryActionResponse.getResult());
   }
 
+  @Test
+  public void testDoExecuteWithSessionId() {
+    CreateAsyncQueryRequest createAsyncQueryRequest =
+        new CreateAsyncQueryRequest(
+            "source = my_glue.default.alb_logs", "my_glue", LangType.SQL, MOCK_SESSION_ID);
+    CreateAsyncQueryActionRequest request =
+        new CreateAsyncQueryActionRequest(createAsyncQueryRequest);
+    when(jobExecutorService.createAsyncQuery(createAsyncQueryRequest))
+        .thenReturn(new CreateAsyncQueryResponse("123", MOCK_SESSION_ID));
+    action.doExecute(task, request, actionListener);
+    Mockito.verify(actionListener).onResponse(createJobActionResponseArgumentCaptor.capture());
+    CreateAsyncQueryActionResponse createAsyncQueryActionResponse =
+        createJobActionResponseArgumentCaptor.getValue();
+    Assertions.assertEquals(
+        "{\n" + "  \"queryId\": \"123\",\n" + "  \"sessionId\": \"s-0123456\"\n" + "}",
+        createAsyncQueryActionResponse.getResult());
+  }
+
   @Test
   public void testDoExecuteWithException() {
     CreateAsyncQueryRequest createAsyncQueryRequest =
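The transport tests above fix the serialized response shape: queryId is always present, sessionId only for interactive queries. On the request side, the four-argument constructor exercised in testDoExecuteWithSessionId is how a caller pins a query to an existing session; a hedged sketch, with a placeholder session id:

// Sketch: submit a follow-up query against an already-running interactive session.
CreateAsyncQueryRequest followUp =
    new CreateAsyncQueryRequest("select 2", "my_glue", LangType.SQL, "00fdjevgkf12s00q");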
diff --git a/spark/src/test/java/org/opensearch/sql/spark/transport/TransportGetAsyncQueryResultActionTest.java b/spark/src/test/java/org/opensearch/sql/spark/transport/TransportGetAsyncQueryResultActionTest.java
index 21a213c7c2..34f10b0083 100644
--- a/spark/src/test/java/org/opensearch/sql/spark/transport/TransportGetAsyncQueryResultActionTest.java
+++ b/spark/src/test/java/org/opensearch/sql/spark/transport/TransportGetAsyncQueryResultActionTest.java
@@ -63,7 +63,7 @@ public void setUp() {
   public void testDoExecute() {
     GetAsyncQueryResultActionRequest request = new GetAsyncQueryResultActionRequest("jobId");
     AsyncQueryExecutionResponse asyncQueryExecutionResponse =
-        new AsyncQueryExecutionResponse("IN_PROGRESS", null, null, null);
+        new AsyncQueryExecutionResponse("IN_PROGRESS", null, null, null, null);
     when(jobExecutorService.getAsyncQueryResults("jobId")).thenReturn(asyncQueryExecutionResponse);
     action.doExecute(task, request, actionListener);
     verify(actionListener).onResponse(createJobActionResponseArgumentCaptor.capture());
@@ -89,6 +89,7 @@ public void testDoExecuteWithSuccessResponse() {
             Arrays.asList(
                 tupleValue(ImmutableMap.of("name", "John", "age", 20)),
                 tupleValue(ImmutableMap.of("name", "Smith", "age", 30))),
+            null,
             null);
     when(jobExecutorService.getAsyncQueryResults("jobId")).thenReturn(asyncQueryExecutionResponse);
     action.doExecute(task, request, actionListener);
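AsyncQueryExecutionResponse gains a fifth constructor argument in this file, and both updated tests pass null for it. Consistent with the session plumbing elsewhere in this PR it presumably carries the session id, though the field name is not visible in these hunks; a sketch under that assumption:

// Assumption: the new trailing argument is the session id (null when none is attached).
AsyncQueryExecutionResponse inProgress =
    new AsyncQueryExecutionResponse("IN_PROGRESS", null, null, null, "00fdjevgkf12s00q");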
diff --git a/spark/src/test/java/org/opensearch/sql/spark/utils/SQLQueryUtilsTest.java b/spark/src/test/java/org/opensearch/sql/spark/utils/SQLQueryUtilsTest.java
index af892fa097..c86d7656d6 100644
--- a/spark/src/test/java/org/opensearch/sql/spark/utils/SQLQueryUtilsTest.java
+++ b/spark/src/test/java/org/opensearch/sql/spark/utils/SQLQueryUtilsTest.java
@@ -6,6 +6,7 @@
 package org.opensearch.sql.spark.utils;
 
 import static org.opensearch.sql.spark.utils.SQLQueryUtilsTest.IndexQuery.index;
+import static org.opensearch.sql.spark.utils.SQLQueryUtilsTest.IndexQuery.mv;
 import static org.opensearch.sql.spark.utils.SQLQueryUtilsTest.IndexQuery.skippingIndex;
 
 import lombok.Getter;
@@ -14,7 +15,9 @@
 import org.junit.jupiter.api.extension.ExtendWith;
 import org.mockito.junit.jupiter.MockitoExtension;
 import org.opensearch.sql.spark.dispatcher.model.FullyQualifiedTableName;
-import org.opensearch.sql.spark.dispatcher.model.IndexDetails;
+import org.opensearch.sql.spark.dispatcher.model.IndexQueryActionType;
+import org.opensearch.sql.spark.dispatcher.model.IndexQueryDetails;
+import org.opensearch.sql.spark.flint.FlintIndexType;
 
 @ExtendWith(MockitoExtension.class)
 public class SQLQueryUtilsTest {
@@ -24,13 +27,13 @@ void testExtractionOfTableNameFromSQLQueries() {
     String sqlQuery = "select * from my_glue.default.http_logs";
     FullyQualifiedTableName fullyQualifiedTableName =
         SQLQueryUtils.extractFullyQualifiedTableName(sqlQuery);
-    Assertions.assertFalse(SQLQueryUtils.isIndexQuery(sqlQuery));
+    Assertions.assertFalse(SQLQueryUtils.isFlintExtensionQuery(sqlQuery));
     Assertions.assertEquals("my_glue", fullyQualifiedTableName.getDatasourceName());
     Assertions.assertEquals("default", fullyQualifiedTableName.getSchemaName());
     Assertions.assertEquals("http_logs", fullyQualifiedTableName.getTableName());
 
     sqlQuery = "select * from my_glue.db.http_logs";
-    Assertions.assertFalse(SQLQueryUtils.isIndexQuery(sqlQuery));
+    Assertions.assertFalse(SQLQueryUtils.isFlintExtensionQuery(sqlQuery));
     fullyQualifiedTableName = SQLQueryUtils.extractFullyQualifiedTableName(sqlQuery);
     Assertions.assertEquals("my_glue", fullyQualifiedTableName.getDatasourceName());
     Assertions.assertEquals("db", fullyQualifiedTableName.getSchemaName());
@@ -38,28 +41,28 @@ void testExtractionOfTableNameFromSQLQueries() {
 
     sqlQuery = "select * from my_glue.http_logs";
     fullyQualifiedTableName = SQLQueryUtils.extractFullyQualifiedTableName(sqlQuery);
-    Assertions.assertFalse(SQLQueryUtils.isIndexQuery(sqlQuery));
+    Assertions.assertFalse(SQLQueryUtils.isFlintExtensionQuery(sqlQuery));
     Assertions.assertEquals("my_glue", fullyQualifiedTableName.getSchemaName());
     Assertions.assertNull(fullyQualifiedTableName.getDatasourceName());
     Assertions.assertEquals("http_logs", fullyQualifiedTableName.getTableName());
 
     sqlQuery = "select * from http_logs";
     fullyQualifiedTableName = SQLQueryUtils.extractFullyQualifiedTableName(sqlQuery);
-    Assertions.assertFalse(SQLQueryUtils.isIndexQuery(sqlQuery));
+    Assertions.assertFalse(SQLQueryUtils.isFlintExtensionQuery(sqlQuery));
     Assertions.assertNull(fullyQualifiedTableName.getDatasourceName());
     Assertions.assertNull(fullyQualifiedTableName.getSchemaName());
     Assertions.assertEquals("http_logs", fullyQualifiedTableName.getTableName());
 
     sqlQuery = "DROP TABLE myS3.default.alb_logs";
     fullyQualifiedTableName = SQLQueryUtils.extractFullyQualifiedTableName(sqlQuery);
-    Assertions.assertFalse(SQLQueryUtils.isIndexQuery(sqlQuery));
+    Assertions.assertFalse(SQLQueryUtils.isFlintExtensionQuery(sqlQuery));
     Assertions.assertEquals("myS3", fullyQualifiedTableName.getDatasourceName());
     Assertions.assertEquals("default", fullyQualifiedTableName.getSchemaName());
     Assertions.assertEquals("alb_logs", fullyQualifiedTableName.getTableName());
 
     sqlQuery = "DESCRIBE TABLE myS3.default.alb_logs";
     fullyQualifiedTableName = SQLQueryUtils.extractFullyQualifiedTableName(sqlQuery);
-    Assertions.assertFalse(SQLQueryUtils.isIndexQuery(sqlQuery));
+    Assertions.assertFalse(SQLQueryUtils.isFlintExtensionQuery(sqlQuery));
     Assertions.assertEquals("myS3", fullyQualifiedTableName.getDatasourceName());
     Assertions.assertEquals("default", fullyQualifiedTableName.getSchemaName());
     Assertions.assertEquals("alb_logs", fullyQualifiedTableName.getTableName());
@@ -72,7 +75,7 @@ void testExtractionOfTableNameFromSQLQueries() {
             + "STORED AS file_format\n"
             + "LOCATION { 's3://bucket/folder/' }";
     fullyQualifiedTableName = SQLQueryUtils.extractFullyQualifiedTableName(sqlQuery);
-    Assertions.assertFalse(SQLQueryUtils.isIndexQuery(sqlQuery));
+    Assertions.assertFalse(SQLQueryUtils.isFlintExtensionQuery(sqlQuery));
     Assertions.assertEquals("myS3", fullyQualifiedTableName.getDatasourceName());
     Assertions.assertEquals("default", fullyQualifiedTableName.getSchemaName());
     Assertions.assertEquals("alb_logs", fullyQualifiedTableName.getTableName());
@@ -91,7 +94,7 @@ void testErrorScenarios() {
 
     sqlQuery = "DESCRIBE TABLE FROM myS3.default.alb_logs";
     fullyQualifiedTableName = SQLQueryUtils.extractFullyQualifiedTableName(sqlQuery);
-    Assertions.assertFalse(SQLQueryUtils.isIndexQuery(sqlQuery));
+    Assertions.assertFalse(SQLQueryUtils.isFlintExtensionQuery(sqlQuery));
     Assertions.assertEquals("FROM", fullyQualifiedTableName.getFullyQualifiedName());
     Assertions.assertNull(fullyQualifiedTableName.getSchemaName());
     Assertions.assertEquals("FROM", fullyQualifiedTableName.getTableName());
@@ -103,59 +106,165 @@ void testExtractionFromFlintIndexQueries() {
     String createCoveredIndexQuery =
         "CREATE INDEX elb_and_requestUri ON myS3.default.alb_logs(l_orderkey, l_quantity) WITH"
             + " (auto_refresh = true)";
-    Assertions.assertTrue(SQLQueryUtils.isIndexQuery(createCoveredIndexQuery));
-    IndexDetails indexDetails = SQLQueryUtils.extractIndexDetails(createCoveredIndexQuery);
-    FullyQualifiedTableName fullyQualifiedTableName = indexDetails.getFullyQualifiedTableName();
-    Assertions.assertEquals("elb_and_requestUri", indexDetails.getIndexName());
+    Assertions.assertTrue(SQLQueryUtils.isFlintExtensionQuery(createCoveredIndexQuery));
+    IndexQueryDetails indexQueryDetails =
+        SQLQueryUtils.extractIndexDetails(createCoveredIndexQuery);
+    FullyQualifiedTableName fullyQualifiedTableName =
+        indexQueryDetails.getFullyQualifiedTableName();
+    Assertions.assertEquals("elb_and_requestUri", indexQueryDetails.getIndexName());
     Assertions.assertEquals("myS3", fullyQualifiedTableName.getDatasourceName());
     Assertions.assertEquals("default", fullyQualifiedTableName.getSchemaName());
     Assertions.assertEquals("alb_logs", fullyQualifiedTableName.getTableName());
   }
 
+  @Test
+  void testExtractionFromFlintMVQuery() {
+    String createCoveredIndexQuery =
+        "CREATE MATERIALIZED VIEW mv_1 AS query=select * from my_glue.default.logs WITH"
+            + " (auto_refresh = true)";
+    Assertions.assertTrue(SQLQueryUtils.isFlintExtensionQuery(createCoveredIndexQuery));
+    IndexQueryDetails indexQueryDetails =
+        SQLQueryUtils.extractIndexDetails(createCoveredIndexQuery);
+    FullyQualifiedTableName fullyQualifiedTableName =
+        indexQueryDetails.getFullyQualifiedTableName();
+    Assertions.assertNull(indexQueryDetails.getIndexName());
+    Assertions.assertNull(fullyQualifiedTableName);
+    Assertions.assertEquals("mv_1", indexQueryDetails.getMvName());
+  }
+
+  @Test
+  void testDescIndex() {
+    String descSkippingIndex = "DESC SKIPPING INDEX ON mys3.default.http_logs";
+    Assertions.assertTrue(SQLQueryUtils.isFlintExtensionQuery(descSkippingIndex));
+    IndexQueryDetails indexDetails = SQLQueryUtils.extractIndexDetails(descSkippingIndex);
+    FullyQualifiedTableName fullyQualifiedTableName = indexDetails.getFullyQualifiedTableName();
+    Assertions.assertNull(indexDetails.getIndexName());
+    Assertions.assertNotNull(fullyQualifiedTableName);
+    Assertions.assertEquals(FlintIndexType.SKIPPING, indexDetails.getIndexType());
+    Assertions.assertEquals(IndexQueryActionType.DESCRIBE, indexDetails.getIndexQueryActionType());
+
+    String descCoveringIndex = "DESC INDEX cv1 ON mys3.default.http_logs";
+    Assertions.assertTrue(SQLQueryUtils.isFlintExtensionQuery(descCoveringIndex));
+    indexDetails = SQLQueryUtils.extractIndexDetails(descCoveringIndex);
+    fullyQualifiedTableName = indexDetails.getFullyQualifiedTableName();
+    Assertions.assertEquals("cv1", indexDetails.getIndexName());
+    Assertions.assertNotNull(fullyQualifiedTableName);
+    Assertions.assertEquals(FlintIndexType.COVERING, indexDetails.getIndexType());
+    Assertions.assertEquals(IndexQueryActionType.DESCRIBE, indexDetails.getIndexQueryActionType());
+
+    String descMv = "DESC MATERIALIZED VIEW mv1";
+    Assertions.assertTrue(SQLQueryUtils.isFlintExtensionQuery(descMv));
+    indexDetails = SQLQueryUtils.extractIndexDetails(descMv);
+    fullyQualifiedTableName = indexDetails.getFullyQualifiedTableName();
+    Assertions.assertNull(indexDetails.getIndexName());
+    Assertions.assertEquals("mv1", indexDetails.getMvName());
+    Assertions.assertNull(fullyQualifiedTableName);
+    Assertions.assertEquals(FlintIndexType.MATERIALIZED_VIEW, indexDetails.getIndexType());
+    Assertions.assertEquals(IndexQueryActionType.DESCRIBE, indexDetails.getIndexQueryActionType());
+  }
+
+  @Test
+  void testShowIndex() {
+    String showCoveringIndex = " SHOW INDEX ON myS3.default.http_logs";
+    Assertions.assertTrue(SQLQueryUtils.isFlintExtensionQuery(showCoveringIndex));
+    IndexQueryDetails indexDetails = SQLQueryUtils.extractIndexDetails(showCoveringIndex);
+    FullyQualifiedTableName fullyQualifiedTableName = indexDetails.getFullyQualifiedTableName();
+    Assertions.assertNull(indexDetails.getIndexName());
+    Assertions.assertNull(indexDetails.getMvName());
+    Assertions.assertNotNull(fullyQualifiedTableName);
+    Assertions.assertEquals(FlintIndexType.COVERING, indexDetails.getIndexType());
+    Assertions.assertEquals(IndexQueryActionType.SHOW, indexDetails.getIndexQueryActionType());
+
+    String showMV = "SHOW MATERIALIZED VIEW IN my_glue.default";
+    Assertions.assertTrue(SQLQueryUtils.isFlintExtensionQuery(showMV));
+    indexDetails = SQLQueryUtils.extractIndexDetails(showMV);
+    fullyQualifiedTableName = indexDetails.getFullyQualifiedTableName();
+    Assertions.assertNull(indexDetails.getIndexName());
+    Assertions.assertNull(indexDetails.getMvName());
+    Assertions.assertNull(fullyQualifiedTableName);
+    Assertions.assertEquals(FlintIndexType.MATERIALIZED_VIEW, indexDetails.getIndexType());
+    Assertions.assertEquals(IndexQueryActionType.SHOW, indexDetails.getIndexQueryActionType());
+  }
+
+  @Test
+  void testRefreshIndex() {
+    String refreshSkippingIndex = "REFRESH SKIPPING INDEX ON mys3.default.http_logs";
+    Assertions.assertTrue(SQLQueryUtils.isFlintExtensionQuery(refreshSkippingIndex));
+    IndexQueryDetails indexDetails = SQLQueryUtils.extractIndexDetails(refreshSkippingIndex);
+    FullyQualifiedTableName fullyQualifiedTableName = indexDetails.getFullyQualifiedTableName();
+    Assertions.assertNull(indexDetails.getIndexName());
+    Assertions.assertNotNull(fullyQualifiedTableName);
+    Assertions.assertEquals(FlintIndexType.SKIPPING, indexDetails.getIndexType());
+    Assertions.assertEquals(IndexQueryActionType.REFRESH, indexDetails.getIndexQueryActionType());
+
+    String refreshCoveringIndex = "REFRESH INDEX cv1 ON mys3.default.http_logs";
+    Assertions.assertTrue(SQLQueryUtils.isFlintExtensionQuery(refreshCoveringIndex));
+    indexDetails = SQLQueryUtils.extractIndexDetails(refreshCoveringIndex);
+    fullyQualifiedTableName = indexDetails.getFullyQualifiedTableName();
+    Assertions.assertEquals("cv1", indexDetails.getIndexName());
+    Assertions.assertNotNull(fullyQualifiedTableName);
+    Assertions.assertEquals(FlintIndexType.COVERING, indexDetails.getIndexType());
+    Assertions.assertEquals(IndexQueryActionType.REFRESH, indexDetails.getIndexQueryActionType());
+
+    String refreshMV = "REFRESH MATERIALIZED VIEW mv1";
+    Assertions.assertTrue(SQLQueryUtils.isFlintExtensionQuery(refreshMV));
+    indexDetails = SQLQueryUtils.extractIndexDetails(refreshMV);
+    fullyQualifiedTableName = indexDetails.getFullyQualifiedTableName();
+    Assertions.assertNull(indexDetails.getIndexName());
+    Assertions.assertEquals("mv1", indexDetails.getMvName());
+    Assertions.assertNull(fullyQualifiedTableName);
+    Assertions.assertEquals(FlintIndexType.MATERIALIZED_VIEW, indexDetails.getIndexType());
+    Assertions.assertEquals(IndexQueryActionType.REFRESH, indexDetails.getIndexQueryActionType());
+  }
+
   /** https://github.com/opensearch-project/sql/issues/2206 */
   @Test
   void testAutoRefresh() {
     Assertions.assertFalse(
-        SQLQueryUtils.extractIndexDetails(skippingIndex().getQuery()).getAutoRefresh());
+        SQLQueryUtils.extractIndexDetails(skippingIndex().getQuery()).isAutoRefresh());
     Assertions.assertFalse(
         SQLQueryUtils.extractIndexDetails(
                 skippingIndex().withProperty("auto_refresh", "false").getQuery())
-            .getAutoRefresh());
+            .isAutoRefresh());
     Assertions.assertTrue(
         SQLQueryUtils.extractIndexDetails(
                 skippingIndex().withProperty("auto_refresh", "true").getQuery())
-            .getAutoRefresh());
+            .isAutoRefresh());
     Assertions.assertTrue(
         SQLQueryUtils.extractIndexDetails(
                 skippingIndex().withProperty("\"auto_refresh\"", "true").getQuery())
-            .getAutoRefresh());
+            .isAutoRefresh());
     Assertions.assertTrue(
         SQLQueryUtils.extractIndexDetails(
                 skippingIndex().withProperty("\"auto_refresh\"", "\"true\"").getQuery())
-            .getAutoRefresh());
+            .isAutoRefresh());
     Assertions.assertFalse(
         SQLQueryUtils.extractIndexDetails(
                 skippingIndex().withProperty("auto_refresh", "1").getQuery())
-            .getAutoRefresh());
+            .isAutoRefresh());
     Assertions.assertFalse(
         SQLQueryUtils.extractIndexDetails(skippingIndex().withProperty("interval", "1").getQuery())
-            .getAutoRefresh());
+            .isAutoRefresh());
 
-    Assertions.assertFalse(SQLQueryUtils.extractIndexDetails(index().getQuery()).getAutoRefresh());
+    Assertions.assertFalse(SQLQueryUtils.extractIndexDetails(index().getQuery()).isAutoRefresh());
     Assertions.assertFalse(
         SQLQueryUtils.extractIndexDetails(index().withProperty("auto_refresh", "false").getQuery())
-            .getAutoRefresh());
+            .isAutoRefresh());
     Assertions.assertTrue(
         SQLQueryUtils.extractIndexDetails(index().withProperty("auto_refresh", "true").getQuery())
-            .getAutoRefresh());
+            .isAutoRefresh());
+
+    Assertions.assertTrue(
+        SQLQueryUtils.extractIndexDetails(mv().withProperty("auto_refresh", "true").getQuery())
+            .isAutoRefresh());
   }
 
   @Getter
@@ -176,6 +285,11 @@ public static IndexQuery index() {
           "CREATE INDEX elb_and_requestUri ON myS3.default.alb_logs(l_orderkey, "
               + "l_quantity)");
     }
 
+    public static IndexQuery mv() {
+      return new IndexQuery(
+          "CREATE MATERIALIZED VIEW mv_1 AS query=select * from my_glue.default.logs");
+    }
+
     public IndexQuery withProperty(String key, String value) {
       query = String.format("%s with (%s = %s)", query, key, value);
       return this;
diff --git a/sql/build.gradle b/sql/build.gradle
index 2984158e57..c9b46d38f1 100644
--- a/sql/build.gradle
+++ b/sql/build.gradle
@@ -47,7 +47,7 @@ dependencies {
 
     implementation "org.antlr:antlr4-runtime:4.7.1"
     implementation group: 'com.google.guava', name: 'guava', version: '32.0.1-jre'
-    implementation group: 'org.json', name: 'json', version:'20230227'
+    implementation group: 'org.json', name: 'json', version:'20231013'
    implementation project(':common')
    implementation project(':core')
    api project(':protocol')
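Taken together, the SQLQueryUtilsTest changes pin down the renamed front door: isIndexQuery becomes isFlintExtensionQuery (it now matches DESC, SHOW, REFRESH, and MATERIALIZED VIEW statements, not only CREATE INDEX), and extractIndexDetails returns an IndexQueryDetails that callers route on. A minimal sketch using only queries and values asserted in the tests above:

// Sketch: dispatching on a Flint extension statement.
String query = "REFRESH MATERIALIZED VIEW mv1";
if (SQLQueryUtils.isFlintExtensionQuery(query)) {
  IndexQueryDetails details = SQLQueryUtils.extractIndexDetails(query);
  // REFRESH action, MATERIALIZED_VIEW index type, mv name "mv1";
  // getFullyQualifiedTableName() is null for materialized views.
}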