-
Notifications
You must be signed in to change notification settings - Fork 140
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Enhance batch job task management by adding default action types #3080
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -23,6 +23,7 @@ | |
import lombok.Builder; | ||
import lombok.EqualsAndHashCode; | ||
import lombok.Getter; | ||
import lombok.Setter; | ||
|
||
@Getter | ||
@EqualsAndHashCode | ||
|
@@ -36,6 +37,7 @@ public class ConnectorAction implements ToXContentObject, Writeable { | |
public static final String ACTION_PRE_PROCESS_FUNCTION = "pre_process_function"; | ||
public static final String ACTION_POST_PROCESS_FUNCTION = "post_process_function"; | ||
|
||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No need to add this empty line. |
||
private ActionType actionType; | ||
private String method; | ||
private String url; | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,6 +6,7 @@ | |
package org.opensearch.ml.common.output; | ||
|
||
import java.io.IOException; | ||
import java.util.Map; | ||
|
||
import org.opensearch.core.common.io.stream.StreamInput; | ||
import org.opensearch.core.common.io.stream.StreamOutput; | ||
|
@@ -29,9 +30,11 @@ public class MLPredictionOutput extends MLOutput { | |
public static final String TASK_ID_FIELD = "task_id"; | ||
public static final String STATUS_FIELD = "status"; | ||
public static final String PREDICTION_RESULT_FIELD = "prediction_result"; | ||
public static final String REMOTE_JOB_FIELD = "remote_job"; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. let's add a comment about this field. |
||
|
||
String taskId; | ||
String status; | ||
Map<String, Object> remoteJob; | ||
|
||
@ToString.Exclude | ||
DataFrame predictionResult; | ||
|
@@ -44,6 +47,14 @@ public MLPredictionOutput(String taskId, String status, DataFrame predictionResu | |
this.predictionResult = predictionResult; | ||
} | ||
|
||
@Builder | ||
public MLPredictionOutput(String taskId, String status, Map<String, Object> remoteJob) { | ||
super(OUTPUT_TYPE); | ||
this.taskId = taskId; | ||
this.status = status; | ||
this.remoteJob = remoteJob; | ||
} | ||
|
||
public MLPredictionOutput(StreamInput in) throws IOException { | ||
super(OUTPUT_TYPE); | ||
this.taskId = in.readOptionalString(); | ||
|
@@ -56,6 +67,9 @@ public MLPredictionOutput(StreamInput in) throws IOException { | |
break; | ||
} | ||
} | ||
if (in.readBoolean()) { | ||
this.remoteJob = in.readMap(s -> s.readString(), s -> s.readGenericValue()); | ||
} | ||
} | ||
|
||
@Override | ||
|
@@ -69,6 +83,12 @@ public void writeTo(StreamOutput out) throws IOException { | |
} else { | ||
out.writeBoolean(false); | ||
} | ||
if (remoteJob != null) { | ||
out.writeBoolean(true); | ||
out.writeMap(remoteJob, StreamOutput::writeString, StreamOutput::writeGenericValue); | ||
} else { | ||
out.writeBoolean(false); | ||
} | ||
} | ||
|
||
@Override | ||
|
@@ -87,6 +107,10 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws | |
builder.endObject(); | ||
} | ||
|
||
if (remoteJob != null) { | ||
builder.field(REMOTE_JOB_FIELD, remoteJob); | ||
} | ||
|
||
builder.endObject(); | ||
return builder; | ||
} | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,15 +8,18 @@ | |
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; | ||
import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX; | ||
import static org.opensearch.ml.common.CommonValue.ML_TASK_INDEX; | ||
import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.BATCH_PREDICT; | ||
import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.CANCEL_BATCH_PREDICT; | ||
import static org.opensearch.ml.utils.MLExceptionUtils.BATCH_INFERENCE_DISABLED_ERR_MSG; | ||
import static org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry; | ||
|
||
import java.util.Collections; | ||
import java.util.HashMap; | ||
import java.util.Map; | ||
import java.util.Optional; | ||
|
||
import org.apache.hc.core5.http.HttpStatus; | ||
import org.apache.commons.text.StringSubstitutor; | ||
import org.opensearch.OpenSearchException; | ||
import org.opensearch.OpenSearchStatusException; | ||
import org.opensearch.ResourceNotFoundException; | ||
|
@@ -38,6 +41,7 @@ | |
import org.opensearch.ml.common.MLTask; | ||
import org.opensearch.ml.common.MLTaskType; | ||
import org.opensearch.ml.common.connector.Connector; | ||
import org.opensearch.ml.common.connector.ConnectorAction; | ||
import org.opensearch.ml.common.connector.ConnectorAction.ActionType; | ||
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; | ||
import org.opensearch.ml.common.exception.MLResourceNotFoundException; | ||
|
@@ -210,6 +214,11 @@ private void processRemoteBatchPrediction(MLTask mlTask, ActionListener<MLCancel | |
|
||
private void executeConnector(Connector connector, MLInput mlInput, ActionListener<MLCancelBatchJobResponse> actionListener) { | ||
if (connectorAccessControlHelper.validateConnectorAccess(client, connector)) { | ||
Optional<ConnectorAction> cancelBatchPredictAction = connector.findAction(CANCEL_BATCH_PREDICT.name()); | ||
if (!cancelBatchPredictAction.isPresent() || cancelBatchPredictAction.get().getRequestBody() == null) { | ||
ConnectorAction connectorAction = createConnectorAction(connector); | ||
connector.setAction(connectorAction); | ||
} | ||
connector.decrypt(CANCEL_BATCH_PREDICT.name(), (credential) -> encryptor.decrypt(credential)); | ||
RemoteConnectorExecutor connectorExecutor = MLEngineClassLoader | ||
.initInstance(connector.getProtocol(), connector, Connector.class); | ||
|
@@ -245,4 +254,61 @@ private void processTaskResponse(MLTaskResponse taskResponse, ActionListener<MLC | |
log.error("Unable to fetch status for ml task ", e); | ||
} | ||
} | ||
|
||
// TODO: move this method to connector utils class | ||
private ConnectorAction createConnectorAction(Connector connector) { | ||
Optional<ConnectorAction> batchPredictAction = connector.findAction(BATCH_PREDICT.name()); | ||
|
||
Map<String, String> headers = batchPredictAction.get().getHeaders(); | ||
|
||
String predictEndpoint = batchPredictAction.get().getUrl(); | ||
Map<String, String> parameters = connector.getParameters() != null | ||
? new HashMap<>(connector.getParameters()) | ||
: Collections.emptyMap(); | ||
|
||
if (!parameters.isEmpty()) { | ||
StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}"); | ||
predictEndpoint = substitutor.replace(predictEndpoint); | ||
} | ||
|
||
String url = ""; | ||
String requestBody = ""; | ||
String method = "POST"; // Default method | ||
|
||
switch (getEndpointType(predictEndpoint)) { | ||
case "sagemaker": | ||
url = predictEndpoint.replace("CreateTransformJob", "StopTransformJob"); | ||
requestBody = "{ \"TransformJobName\" : \"${parameters.TransformJobName}\"}"; | ||
break; | ||
case "openai": | ||
case "cohere": | ||
url = predictEndpoint + "/${parameters.id}/cancel"; | ||
break; | ||
case "bedrock": | ||
url = predictEndpoint + "/${parameters.processedJobArn}/stop"; | ||
break; | ||
} | ||
|
||
return ConnectorAction | ||
.builder() | ||
.actionType(CANCEL_BATCH_PREDICT) | ||
.method(method) | ||
.url(url) | ||
.requestBody(requestBody) | ||
.headers(headers) | ||
.build(); | ||
|
||
} | ||
|
||
private String getEndpointType(String url) { | ||
if (url.contains("sagemaker")) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can we make static variables for all these values and then re-use? |
||
return "sagemaker"; | ||
if (url.contains("openai")) | ||
return "openai"; | ||
if (url.contains("bedrock")) | ||
return "bedrock"; | ||
if (url.contains("cohere")) | ||
return "cohere"; | ||
return ""; | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,6 +14,7 @@ | |
import static org.opensearch.ml.common.MLTaskState.CANCELLING; | ||
import static org.opensearch.ml.common.MLTaskState.COMPLETED; | ||
import static org.opensearch.ml.common.MLTaskState.EXPIRED; | ||
import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.BATCH_PREDICT; | ||
import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.BATCH_PREDICT_STATUS; | ||
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_CANCELLED_REGEX; | ||
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_CANCELLING_REGEX; | ||
|
@@ -24,6 +25,7 @@ | |
import static org.opensearch.ml.utils.MLExceptionUtils.logException; | ||
import static org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry; | ||
|
||
import java.util.Collections; | ||
import java.util.HashMap; | ||
import java.util.List; | ||
import java.util.Map; | ||
|
@@ -32,6 +34,7 @@ | |
import java.util.regex.Matcher; | ||
import java.util.regex.Pattern; | ||
|
||
import org.apache.commons.text.StringSubstitutor; | ||
import org.opensearch.OpenSearchException; | ||
import org.opensearch.OpenSearchStatusException; | ||
import org.opensearch.ResourceNotFoundException; | ||
|
@@ -55,6 +58,7 @@ | |
import org.opensearch.ml.common.MLTask; | ||
import org.opensearch.ml.common.MLTaskType; | ||
import org.opensearch.ml.common.connector.Connector; | ||
import org.opensearch.ml.common.connector.ConnectorAction; | ||
import org.opensearch.ml.common.connector.ConnectorAction.ActionType; | ||
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; | ||
import org.opensearch.ml.common.exception.MLResourceNotFoundException; | ||
|
@@ -279,6 +283,11 @@ private void executeConnector( | |
ActionListener<MLTaskGetResponse> actionListener | ||
) { | ||
if (connectorAccessControlHelper.validateConnectorAccess(client, connector)) { | ||
Optional<ConnectorAction> batchPredictStatusAction = connector.findAction(BATCH_PREDICT_STATUS.name()); | ||
if (!batchPredictStatusAction.isPresent() || batchPredictStatusAction.get().getRequestBody() == null) { | ||
ConnectorAction connectorAction = createConnectorAction(connector); | ||
connector.setAction(connectorAction); | ||
} | ||
connector.decrypt(BATCH_PREDICT_STATUS.name(), (credential) -> encryptor.decrypt(credential)); | ||
RemoteConnectorExecutor connectorExecutor = MLEngineClassLoader | ||
.initInstance(connector.getProtocol(), connector, Connector.class); | ||
|
@@ -362,4 +371,62 @@ private boolean matchesPattern(Pattern pattern, String input) { | |
Matcher matcher = pattern.matcher(input); | ||
return matcher.find(); | ||
} | ||
|
||
// TODO: move this method to connector utils class | ||
private ConnectorAction createConnectorAction(Connector connector) { | ||
Optional<ConnectorAction> batchPredictAction = connector.findAction(BATCH_PREDICT.name()); | ||
|
||
Map<String, String> headers = batchPredictAction.get().getHeaders(); | ||
|
||
String predictEndpoint = batchPredictAction.get().getUrl(); | ||
Map<String, String> parameters = connector.getParameters() != null | ||
? new HashMap<>(connector.getParameters()) | ||
: Collections.emptyMap(); | ||
|
||
// Apply parameter substitution only if needed | ||
if (!parameters.isEmpty()) { | ||
StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}"); | ||
predictEndpoint = substitutor.replace(predictEndpoint); | ||
} | ||
|
||
String url = ""; | ||
String requestBody = ""; | ||
String method = "GET"; | ||
|
||
switch (getEndpointType(predictEndpoint)) { | ||
case "sagemaker": | ||
url = predictEndpoint.replace("CreateTransformJob", "DescribeTransformJob"); | ||
requestBody = "{ \"TransformJobName\" : \"${parameters.TransformJobName}\"}"; | ||
method = "POST"; | ||
break; | ||
case "openai": | ||
case "cohere": | ||
url = predictEndpoint + "/${parameters.id}"; | ||
break; | ||
case "bedrock": | ||
url = predictEndpoint + "/${parameters.processedJobArn}"; | ||
break; | ||
} | ||
return ConnectorAction | ||
.builder() | ||
.actionType(BATCH_PREDICT_STATUS) | ||
.method(method) | ||
.url(url) | ||
.requestBody(requestBody) | ||
.headers(headers) | ||
.build(); | ||
|
||
} | ||
|
||
private String getEndpointType(String url) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why do we need same method 2 times? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes will be moving the common methods to ConnectorUtils class |
||
if (url.contains("sagemaker")) | ||
return "sagemaker"; | ||
if (url.contains("openai")) | ||
return "openai"; | ||
if (url.contains("bedrock")) | ||
return "bedrock"; | ||
if (url.contains("cohere")) | ||
return "cohere"; | ||
return ""; | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Using addAction as the name is more accurate?