Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change emr job names based on the query type #2543

Merged
merged 1 commit into from
Mar 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions spark/src/main/antlr/SqlBaseParser.g4
Original file line number Diff line number Diff line change
Expand Up @@ -989,7 +989,7 @@ primaryExpression
| CASE whenClause+ (ELSE elseExpression=expression)? END #searchedCase
| CASE value=expression whenClause+ (ELSE elseExpression=expression)? END #simpleCase
| name=(CAST | TRY_CAST) LEFT_PAREN expression AS dataType RIGHT_PAREN #cast
| primaryExpression collateClause #collate
| primaryExpression collateClause #collate
vamsimanohar marked this conversation as resolved.
Show resolved Hide resolved
| primaryExpression DOUBLE_COLON dataType #castByColon
| STRUCT LEFT_PAREN (argument+=namedExpression (COMMA argument+=namedExpression)*)? RIGHT_PAREN #struct
| FIRST LEFT_PAREN expression (IGNORE NULLS)? RIGHT_PAREN #first
Expand Down Expand Up @@ -1096,7 +1096,7 @@ colPosition
;

collateClause
: COLLATE collationName=stringLit
: COLLATE collationName=identifier
;

type
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import com.amazonaws.services.emrserverless.model.StartJobRunResult;
import java.security.AccessController;
import java.security.PrivilegedAction;
import org.apache.commons.lang3.StringUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.sql.legacy.metrics.MetricName;
Expand All @@ -29,6 +30,8 @@ public class EmrServerlessClientImpl implements EMRServerlessClient {
private final AWSEMRServerless emrServerless;
private static final Logger logger = LogManager.getLogger(EmrServerlessClientImpl.class);

private static final int MAX_JOB_NAME_LENGTH = 255;

private static final String GENERIC_INTERNAL_SERVER_ERROR_MESSAGE = "Internal Server Error.";

public EmrServerlessClientImpl(AWSEMRServerless emrServerless) {
Expand All @@ -43,7 +46,7 @@ public String startJobRun(StartJobRequest startJobRequest) {
: startJobRequest.getResultIndex();
StartJobRunRequest request =
new StartJobRunRequest()
.withName(startJobRequest.getJobName())
.withName(StringUtils.truncate(startJobRequest.getJobName(), MAX_JOB_NAME_LENGTH))
.withApplicationId(startJobRequest.getApplicationId())
.withExecutionRoleArn(startJobRequest.getExecutionRoleArn())
.withTags(startJobRequest.getTags())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,15 +68,14 @@ public DispatchQueryResponse submit(
leaseManager.borrow(new LeaseRequest(JobType.BATCH, dispatchQueryRequest.getDatasource()));

String clusterName = dispatchQueryRequest.getClusterName();
String jobName = clusterName + ":" + "non-index-query";
Map<String, String> tags = context.getTags();
DataSourceMetadata dataSourceMetadata = context.getDataSourceMetadata();

tags.put(JOB_TYPE_TAG_KEY, JobType.BATCH.getText());
StartJobRequest startJobRequest =
new StartJobRequest(
dispatchQueryRequest.getQuery(),
jobName,
clusterName + ":" + JobType.BATCH.getText(),
dispatchQueryRequest.getApplicationId(),
dispatchQueryRequest.getExecutionRoleARN(),
SparkSubmitParameters.Builder.builder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,7 @@
import org.apache.logging.log4j.Logger;
import org.json.JSONObject;
import org.opensearch.client.Client;
import org.opensearch.sql.datasource.DataSourceService;
import org.opensearch.sql.datasource.model.DataSourceMetadata;
import org.opensearch.sql.datasources.auth.DataSourceUserAuthorizationHelperImpl;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata;
import org.opensearch.sql.spark.client.EMRServerlessClient;
Expand All @@ -44,10 +42,6 @@ public class IndexDMLHandler extends AsyncQueryHandler {

private final EMRServerlessClient emrServerlessClient;

private final DataSourceService dataSourceService;

private final DataSourceUserAuthorizationHelperImpl dataSourceUserAuthorizationHelper;

private final JobExecutionResponseReader jobExecutionResponseReader;

private final FlintIndexMetadataReader flintIndexMetadataReader;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ public DispatchQueryResponse submit(
DispatchQueryRequest dispatchQueryRequest, DispatchQueryContext context) {
Session session = null;
String clusterName = dispatchQueryRequest.getClusterName();
String jobName = clusterName + ":" + "non-index-query";
Map<String, String> tags = context.getTags();
DataSourceMetadata dataSourceMetadata = context.getDataSourceMetadata();

Expand All @@ -94,7 +93,7 @@ public DispatchQueryResponse submit(
session =
sessionManager.createSession(
new CreateSessionRequest(
jobName,
clusterName,
Comment on lines -97 to +96
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the benefit of moving the job-name formatting to the callee side? I'd rather leave it at the handler level for consistency.

Copy link
Member Author

@vamsimanohar vamsimanohar Mar 11, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried to do that, but it requires a lot of refactoring, and the current code structure is a little convoluted.
Currently, sessionId creation happens inside CreateSessionRequest, so the jobName can't be built on the callee side.

dispatchQueryRequest.getApplicationId(),
dispatchQueryRequest.getExecutionRoleARN(),
SparkSubmitParameters.Builder.builder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,6 @@ public String cancelJob(AsyncQueryJobMetadata asyncQueryJobMetadata) {
private IndexDMLHandler createIndexDMLHandler(EMRServerlessClient emrServerlessClient) {
return new IndexDMLHandler(
emrServerlessClient,
dataSourceService,
dataSourceUserAuthorizationHelper,
jobExecutionResponseReader,
flintIndexMetadataReader,
client,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,17 @@ public DispatchQueryResponse submit(
leaseManager.borrow(new LeaseRequest(JobType.STREAMING, dispatchQueryRequest.getDatasource()));

String clusterName = dispatchQueryRequest.getClusterName();
String jobName = clusterName + ":" + "index-query";
IndexQueryDetails indexQueryDetails = context.getIndexQueryDetails();
Map<String, String> tags = context.getTags();
tags.put(INDEX_TAG_KEY, indexQueryDetails.openSearchIndexName());
DataSourceMetadata dataSourceMetadata = context.getDataSourceMetadata();
tags.put(JOB_TYPE_TAG_KEY, JobType.STREAMING.getText());
String jobName =
clusterName
+ ":"
+ JobType.STREAMING.getText()
+ ":"
+ indexQueryDetails.openSearchIndexName();
StartJobRequest startJobRequest =
new StartJobRequest(
dispatchQueryRequest.getQuery(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,22 @@
import lombok.Data;
import org.opensearch.sql.spark.asyncquery.model.SparkSubmitParameters;
import org.opensearch.sql.spark.client.StartJobRequest;
import org.opensearch.sql.spark.dispatcher.model.JobType;

@Data
public class CreateSessionRequest {
private final String jobName;
private final String clusterName;
private final String applicationId;
private final String executionRoleArn;
private final SparkSubmitParameters.Builder sparkSubmitParametersBuilder;
private final Map<String, String> tags;
private final String resultIndex;
private final String datasourceName;

public StartJobRequest getStartJobRequest() {
public StartJobRequest getStartJobRequest(String sessionId) {
return new InteractiveSessionStartJobRequest(
"select 1",
jobName,
clusterName + ":" + JobType.INTERACTIVE.getText() + ":" + sessionId,
applicationId,
executionRoleArn,
sparkSubmitParametersBuilder.build().toString(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.apache.logging.log4j.Logger;
import org.opensearch.index.engine.VersionConflictEngineException;
import org.opensearch.sql.spark.client.EMRServerlessClient;
import org.opensearch.sql.spark.client.StartJobRequest;
import org.opensearch.sql.spark.execution.statement.QueryRequest;
import org.opensearch.sql.spark.execution.statement.Statement;
import org.opensearch.sql.spark.execution.statement.StatementId;
Expand Down Expand Up @@ -55,8 +56,10 @@ public void open(CreateSessionRequest createSessionRequest) {
.getSparkSubmitParametersBuilder()
.sessionExecution(sessionId.getSessionId(), createSessionRequest.getDatasourceName());
createSessionRequest.getTags().put(SESSION_ID_TAG_KEY, sessionId.getSessionId());
String jobID = serverlessClient.startJobRun(createSessionRequest.getStartJobRequest());
String applicationId = createSessionRequest.getStartJobRequest().getApplicationId();
StartJobRequest startJobRequest =
createSessionRequest.getStartJobRequest(sessionId.getSessionId());
String jobID = serverlessClient.startJobRun(startJobRequest);
String applicationId = startJobRequest.getApplicationId();

sessionModel =
initInteractiveSession(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import com.amazonaws.services.emrserverless.model.ValidationException;
import java.util.HashMap;
import java.util.List;
import org.apache.commons.lang3.RandomStringUtils;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
Expand Down Expand Up @@ -175,4 +176,25 @@ void testCancelJobRunWithValidationException() {
() -> emrServerlessClient.cancelJobRun(EMRS_APPLICATION_ID, EMR_JOB_ID));
Assertions.assertEquals("Internal Server Error.", runtimeException.getMessage());
}

@Test
void testStartJobRunWithLongJobName() {
  // Stub EMR Serverless so any submission returns an empty result.
  when(emrServerless.startJobRun(any())).thenReturn(new StartJobRunResult());

  EmrServerlessClientImpl client = new EmrServerlessClientImpl(emrServerless);

  // Submit a request whose job name (300 random characters) exceeds the EMR limit.
  StartJobRequest overlongNameRequest =
      new StartJobRequest(
          QUERY,
          RandomStringUtils.random(300),
          EMRS_APPLICATION_ID,
          EMRS_EXECUTION_ROLE,
          SPARK_SUBMIT_PARAMETERS,
          new HashMap<>(),
          false,
          DEFAULT_RESULT_INDEX);
  client.startJobRun(overlongNameRequest);

  // The client must truncate the submitted job name to the 255-character maximum.
  verify(emrServerless, times(1)).startJobRun(startJobRunRequestArgumentCaptor.capture());
  Assertions.assertEquals(255, startJobRunRequestArgumentCaptor.getValue().getName().length());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class IndexDMLHandlerTest {
@Test
public void getResponseFromExecutor() {
JSONObject result =
new IndexDMLHandler(null, null, null, null, null, null, null).getResponseFromExecutor(null);
new IndexDMLHandler(null, null, null, null, null).getResponseFromExecutor(null);

assertEquals("running", result.getString(STATUS_FIELD));
assertEquals("", result.getString(ERROR_FIELD));
Expand Down
Loading
Loading