Skip to content

Commit

Permalink
Change emr job names based on the query type (#2543)
Browse files Browse the repository at this point in the history
Signed-off-by: Vamsi Manohar <[email protected]>
  • Loading branch information
vmmusings authored Mar 11, 2024
1 parent e8b7419 commit 1a09f96
Show file tree
Hide file tree
Showing 14 changed files with 185 additions and 260 deletions.
4 changes: 2 additions & 2 deletions spark/src/main/antlr/SqlBaseParser.g4
Original file line number Diff line number Diff line change
Expand Up @@ -989,7 +989,7 @@ primaryExpression
| CASE whenClause+ (ELSE elseExpression=expression)? END #searchedCase
| CASE value=expression whenClause+ (ELSE elseExpression=expression)? END #simpleCase
| name=(CAST | TRY_CAST) LEFT_PAREN expression AS dataType RIGHT_PAREN #cast
| primaryExpression collateClause #collate
| primaryExpression collateClause #collate
| primaryExpression DOUBLE_COLON dataType #castByColon
| STRUCT LEFT_PAREN (argument+=namedExpression (COMMA argument+=namedExpression)*)? RIGHT_PAREN #struct
| FIRST LEFT_PAREN expression (IGNORE NULLS)? RIGHT_PAREN #first
Expand Down Expand Up @@ -1096,7 +1096,7 @@ colPosition
;

collateClause
: COLLATE collationName=stringLit
: COLLATE collationName=identifier
;

type
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import com.amazonaws.services.emrserverless.model.StartJobRunResult;
import java.security.AccessController;
import java.security.PrivilegedAction;
import org.apache.commons.lang3.StringUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.sql.legacy.metrics.MetricName;
Expand All @@ -29,6 +30,8 @@ public class EmrServerlessClientImpl implements EMRServerlessClient {
private final AWSEMRServerless emrServerless;
private static final Logger logger = LogManager.getLogger(EmrServerlessClientImpl.class);

private static final int MAX_JOB_NAME_LENGTH = 255;

private static final String GENERIC_INTERNAL_SERVER_ERROR_MESSAGE = "Internal Server Error.";

public EmrServerlessClientImpl(AWSEMRServerless emrServerless) {
Expand All @@ -43,7 +46,7 @@ public String startJobRun(StartJobRequest startJobRequest) {
: startJobRequest.getResultIndex();
StartJobRunRequest request =
new StartJobRunRequest()
.withName(startJobRequest.getJobName())
.withName(StringUtils.truncate(startJobRequest.getJobName(), MAX_JOB_NAME_LENGTH))
.withApplicationId(startJobRequest.getApplicationId())
.withExecutionRoleArn(startJobRequest.getExecutionRoleArn())
.withTags(startJobRequest.getTags())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,15 +68,14 @@ public DispatchQueryResponse submit(
leaseManager.borrow(new LeaseRequest(JobType.BATCH, dispatchQueryRequest.getDatasource()));

String clusterName = dispatchQueryRequest.getClusterName();
String jobName = clusterName + ":" + "non-index-query";
Map<String, String> tags = context.getTags();
DataSourceMetadata dataSourceMetadata = context.getDataSourceMetadata();

tags.put(JOB_TYPE_TAG_KEY, JobType.BATCH.getText());
StartJobRequest startJobRequest =
new StartJobRequest(
dispatchQueryRequest.getQuery(),
jobName,
clusterName + ":" + JobType.BATCH.getText(),
dispatchQueryRequest.getApplicationId(),
dispatchQueryRequest.getExecutionRoleARN(),
SparkSubmitParameters.Builder.builder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,7 @@
import org.apache.logging.log4j.Logger;
import org.json.JSONObject;
import org.opensearch.client.Client;
import org.opensearch.sql.datasource.DataSourceService;
import org.opensearch.sql.datasource.model.DataSourceMetadata;
import org.opensearch.sql.datasources.auth.DataSourceUserAuthorizationHelperImpl;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata;
import org.opensearch.sql.spark.client.EMRServerlessClient;
Expand All @@ -44,10 +42,6 @@ public class IndexDMLHandler extends AsyncQueryHandler {

private final EMRServerlessClient emrServerlessClient;

private final DataSourceService dataSourceService;

private final DataSourceUserAuthorizationHelperImpl dataSourceUserAuthorizationHelper;

private final JobExecutionResponseReader jobExecutionResponseReader;

private final FlintIndexMetadataReader flintIndexMetadataReader;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ public DispatchQueryResponse submit(
DispatchQueryRequest dispatchQueryRequest, DispatchQueryContext context) {
Session session = null;
String clusterName = dispatchQueryRequest.getClusterName();
String jobName = clusterName + ":" + "non-index-query";
Map<String, String> tags = context.getTags();
DataSourceMetadata dataSourceMetadata = context.getDataSourceMetadata();

Expand All @@ -94,7 +93,7 @@ public DispatchQueryResponse submit(
session =
sessionManager.createSession(
new CreateSessionRequest(
jobName,
clusterName,
dispatchQueryRequest.getApplicationId(),
dispatchQueryRequest.getExecutionRoleARN(),
SparkSubmitParameters.Builder.builder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,6 @@ public String cancelJob(AsyncQueryJobMetadata asyncQueryJobMetadata) {
private IndexDMLHandler createIndexDMLHandler(EMRServerlessClient emrServerlessClient) {
return new IndexDMLHandler(
emrServerlessClient,
dataSourceService,
dataSourceUserAuthorizationHelper,
jobExecutionResponseReader,
flintIndexMetadataReader,
client,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,17 @@ public DispatchQueryResponse submit(
leaseManager.borrow(new LeaseRequest(JobType.STREAMING, dispatchQueryRequest.getDatasource()));

String clusterName = dispatchQueryRequest.getClusterName();
String jobName = clusterName + ":" + "index-query";
IndexQueryDetails indexQueryDetails = context.getIndexQueryDetails();
Map<String, String> tags = context.getTags();
tags.put(INDEX_TAG_KEY, indexQueryDetails.openSearchIndexName());
DataSourceMetadata dataSourceMetadata = context.getDataSourceMetadata();
tags.put(JOB_TYPE_TAG_KEY, JobType.STREAMING.getText());
String jobName =
clusterName
+ ":"
+ JobType.STREAMING.getText()
+ ":"
+ indexQueryDetails.openSearchIndexName();
StartJobRequest startJobRequest =
new StartJobRequest(
dispatchQueryRequest.getQuery(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,22 @@
import lombok.Data;
import org.opensearch.sql.spark.asyncquery.model.SparkSubmitParameters;
import org.opensearch.sql.spark.client.StartJobRequest;
import org.opensearch.sql.spark.dispatcher.model.JobType;

@Data
public class CreateSessionRequest {
private final String jobName;
private final String clusterName;
private final String applicationId;
private final String executionRoleArn;
private final SparkSubmitParameters.Builder sparkSubmitParametersBuilder;
private final Map<String, String> tags;
private final String resultIndex;
private final String datasourceName;

public StartJobRequest getStartJobRequest() {
public StartJobRequest getStartJobRequest(String sessionId) {
return new InteractiveSessionStartJobRequest(
"select 1",
jobName,
clusterName + ":" + JobType.INTERACTIVE.getText() + ":" + sessionId,
applicationId,
executionRoleArn,
sparkSubmitParametersBuilder.build().toString(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.apache.logging.log4j.Logger;
import org.opensearch.index.engine.VersionConflictEngineException;
import org.opensearch.sql.spark.client.EMRServerlessClient;
import org.opensearch.sql.spark.client.StartJobRequest;
import org.opensearch.sql.spark.execution.statement.QueryRequest;
import org.opensearch.sql.spark.execution.statement.Statement;
import org.opensearch.sql.spark.execution.statement.StatementId;
Expand Down Expand Up @@ -55,8 +56,10 @@ public void open(CreateSessionRequest createSessionRequest) {
.getSparkSubmitParametersBuilder()
.sessionExecution(sessionId.getSessionId(), createSessionRequest.getDatasourceName());
createSessionRequest.getTags().put(SESSION_ID_TAG_KEY, sessionId.getSessionId());
String jobID = serverlessClient.startJobRun(createSessionRequest.getStartJobRequest());
String applicationId = createSessionRequest.getStartJobRequest().getApplicationId();
StartJobRequest startJobRequest =
createSessionRequest.getStartJobRequest(sessionId.getSessionId());
String jobID = serverlessClient.startJobRun(startJobRequest);
String applicationId = startJobRequest.getApplicationId();

sessionModel =
initInteractiveSession(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import com.amazonaws.services.emrserverless.model.ValidationException;
import java.util.HashMap;
import java.util.List;
import org.apache.commons.lang3.RandomStringUtils;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
Expand Down Expand Up @@ -175,4 +176,25 @@ void testCancelJobRunWithValidationException() {
() -> emrServerlessClient.cancelJobRun(EMRS_APPLICATION_ID, EMR_JOB_ID));
Assertions.assertEquals("Internal Server Error.", runtimeException.getMessage());
}

@Test
void testStartJobRunWithLongJobName() {
StartJobRunResult response = new StartJobRunResult();
when(emrServerless.startJobRun(any())).thenReturn(response);

EmrServerlessClientImpl emrServerlessClient = new EmrServerlessClientImpl(emrServerless);
emrServerlessClient.startJobRun(
new StartJobRequest(
QUERY,
RandomStringUtils.random(300),
EMRS_APPLICATION_ID,
EMRS_EXECUTION_ROLE,
SPARK_SUBMIT_PARAMETERS,
new HashMap<>(),
false,
DEFAULT_RESULT_INDEX));
verify(emrServerless, times(1)).startJobRun(startJobRunRequestArgumentCaptor.capture());
StartJobRunRequest startJobRunRequest = startJobRunRequestArgumentCaptor.getValue();
Assertions.assertEquals(255, startJobRunRequest.getName().length());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class IndexDMLHandlerTest {
@Test
public void getResponseFromExecutor() {
JSONObject result =
new IndexDMLHandler(null, null, null, null, null, null, null).getResponseFromExecutor(null);
new IndexDMLHandler(null, null, null, null, null).getResponseFromExecutor(null);

assertEquals("running", result.getString(STATUS_FIELD));
assertEquals("", result.getString(ERROR_FIELD));
Expand Down
Loading

0 comments on commit 1a09f96

Please sign in to comment.