From ff6fe67b6b913b8c746faac04fc7e2f2b5d184e7 Mon Sep 17 00:00:00 2001 From: Bhavana Ramaram Date: Wed, 9 Oct 2024 17:52:54 -0500 Subject: [PATCH] Enhance batch job task management by adding default action types (#3080) * enhance batch job task management by adding default action types Signed-off-by: Bhavana Ramaram --- .../common/connector/AbstractConnector.java | 5 + .../ml/common/connector/Connector.java | 2 + .../ml/common/output/MLPredictionOutput.java | 26 ++++ .../common/output/MLPredictionOutputTest.java | 14 ++ .../algorithms/remote/ConnectorUtils.java | 64 ++++++++ .../algorithms/remote/ConnectorUtilsTest.java | 141 ++++++++++++++++-- .../tasks/CancelBatchJobTransportAction.java | 7 + .../action/tasks/GetTaskTransportAction.java | 7 + .../ml/task/MLPredictTaskRunner.java | 10 +- .../tasks/GetTaskTransportActionTests.java | 8 + 10 files changed, 266 insertions(+), 18 deletions(-) diff --git a/common/src/main/java/org/opensearch/ml/common/connector/AbstractConnector.java b/common/src/main/java/org/opensearch/ml/common/connector/AbstractConnector.java index 4849f79c93..d8adc7ac54 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/AbstractConnector.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/AbstractConnector.java @@ -125,6 +125,11 @@ public Optional findAction(String action) { return Optional.empty(); } + @Override + public void addAction(ConnectorAction action) { + actions.add(action); + } + @Override public void removeCredential() { this.credential = null; diff --git a/common/src/main/java/org/opensearch/ml/common/connector/Connector.java b/common/src/main/java/org/opensearch/ml/common/connector/Connector.java index 0a37641144..aa2f93235e 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/Connector.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/Connector.java @@ -65,6 +65,8 @@ public interface Connector extends ToXContentObject, Writeable { List getActions(); + void addAction(ConnectorAction action); + ConnectorClientConfig getConnectorClientConfig(); String getActionEndpoint(String action, Map parameters); diff --git a/common/src/main/java/org/opensearch/ml/common/output/MLPredictionOutput.java b/common/src/main/java/org/opensearch/ml/common/output/MLPredictionOutput.java index 5675dab409..830a19623a 100644 --- a/common/src/main/java/org/opensearch/ml/common/output/MLPredictionOutput.java +++ b/common/src/main/java/org/opensearch/ml/common/output/MLPredictionOutput.java @@ -6,6 +6,7 @@ package org.opensearch.ml.common.output; import java.io.IOException; +import java.util.Map; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; @@ -30,8 +31,12 @@ public class MLPredictionOutput extends MLOutput { public static final String STATUS_FIELD = "status"; public static final String PREDICTION_RESULT_FIELD = "prediction_result"; + // This field will be created for offline batch prediction tasks containing details of the batch job as outputted by the remote server. + public static final String REMOTE_JOB_FIELD = "remote_job"; + String taskId; String status; + Map remoteJob; @ToString.Exclude DataFrame predictionResult; @@ -44,6 +49,14 @@ public MLPredictionOutput(String taskId, String status, DataFrame predictionResu this.predictionResult = predictionResult; } + @Builder + public MLPredictionOutput(String taskId, String status, Map remoteJob) { + super(OUTPUT_TYPE); + this.taskId = taskId; + this.status = status; + this.remoteJob = remoteJob; + } + public MLPredictionOutput(StreamInput in) throws IOException { super(OUTPUT_TYPE); this.taskId = in.readOptionalString(); @@ -56,6 +69,9 @@ public MLPredictionOutput(StreamInput in) throws IOException { break; } } + if (in.readBoolean()) { + this.remoteJob = in.readMap(s -> s.readString(), s -> s.readGenericValue()); + } } @Override @@ -69,6 +85,12 @@ public void writeTo(StreamOutput out) throws IOException { } else { out.writeBoolean(false); } + if (remoteJob != null) { + out.writeBoolean(true); + out.writeMap(remoteJob, StreamOutput::writeString, StreamOutput::writeGenericValue); + } else { + out.writeBoolean(false); + } } @Override @@ -87,6 +109,10 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.endObject(); } + if (remoteJob != null) { + builder.field(REMOTE_JOB_FIELD, remoteJob); + } + builder.endObject(); return builder; } diff --git a/common/src/test/java/org/opensearch/ml/common/output/MLPredictionOutputTest.java b/common/src/test/java/org/opensearch/ml/common/output/MLPredictionOutputTest.java index 857e92f5a3..49a1e9355b 100644 --- a/common/src/test/java/org/opensearch/ml/common/output/MLPredictionOutputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/output/MLPredictionOutputTest.java @@ -9,7 +9,9 @@ import java.io.IOException; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; +import java.util.Map; import org.junit.Before; import org.junit.Test; @@ -30,6 +32,7 @@ public class MLPredictionOutputTest { MLPredictionOutput output; + MLPredictionOutput outputWithRemoteJob; @Before public void setUp() { @@ -38,12 +41,17 @@ public void setUp() { rows.add(new Row(new ColumnValue[] { new IntValue(1) })); rows.add(new Row(new ColumnValue[] { new IntValue(2) })); DataFrame dataFrame = new DefaultDataFrame(columnMetas, rows); + Map remoteJob = new HashMap<>(); + remoteJob.put("status", "INPROGRESS"); + remoteJob.put("job_id", "testJobID"); output = MLPredictionOutput.builder().taskId("test_task_id").status("test_status").predictionResult(dataFrame).build(); + outputWithRemoteJob = new MLPredictionOutput("test_task_id", "test_status", remoteJob); } @Test public void toXContent() throws IOException { XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON); + XContentBuilder builderWithRemoteJob = MediaTypeRegistry.contentBuilder(XContentType.JSON); output.toXContent(builder, ToXContent.EMPTY_PARAMS); String jsonStr = builder.toString(); assertEquals( @@ -53,6 +61,12 @@ public void toXContent() throws IOException { + "\"value\":2}]}]}}", jsonStr ); + outputWithRemoteJob.toXContent(builderWithRemoteJob, ToXContent.EMPTY_PARAMS); + String jsonStr2 = builderWithRemoteJob.toString(); + assertEquals( + "{\"task_id\":\"test_task_id\",\"status\":\"test_status\",\"remote_job\":{\"job_id\":\"testJobID\",\"status\":\"INPROGRESS\"}}", + jsonStr2 + ); } @Test diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java index ccceff3d68..f2c93ef5fd 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java @@ -6,6 +6,7 @@ package org.opensearch.ml.engine.algorithms.remote; import static org.apache.commons.text.StringEscapeUtils.escapeJson; +import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.BATCH_PREDICT; import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.CANCEL_BATCH_PREDICT; import static org.opensearch.ml.common.connector.HttpConnector.RESPONSE_FILTER_FIELD; import static org.opensearch.ml.common.connector.MLPreProcessFunction.CONVERT_INPUT_TO_JSON_STRING; @@ -19,6 +20,7 @@ import java.net.URI; import java.nio.charset.Charset; import java.util.ArrayList; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -61,6 +63,9 @@ public class ConnectorUtils { private static final Aws4Signer signer; public static final String SKIP_VALIDATE_MISSING_PARAMETERS = "skip_validating_missing_parameters"; + public static final List SUPPORTED_REMOTE_SERVERS_FOR_DEFAULT_ACTION_TYPES = List + .of("sagemaker", "openai", "bedrock", "cohere"); + static { signer = Aws4Signer.create(); } @@ -313,4 +318,63 @@ public static SdkHttpFullRequest buildSdkRequest( } return builder.build(); } + + public static ConnectorAction createConnectorAction(Connector connector, ConnectorAction.ActionType actionType) { + Optional batchPredictAction = connector.findAction(BATCH_PREDICT.name()); + String predictEndpoint = batchPredictAction.get().getUrl(); + Map parameters = connector.getParameters() != null + ? new HashMap<>(connector.getParameters()) + : Collections.emptyMap(); + + // Apply parameter substitution only if needed + if (!parameters.isEmpty()) { + StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}"); + predictEndpoint = substitutor.replace(predictEndpoint); + } + + boolean isCancelAction = actionType == CANCEL_BATCH_PREDICT; + + // Initialize the default method and requestBody + String method = "POST"; + String requestBody = null; + String url = ""; + + switch (getRemoteServerFromURL(predictEndpoint)) { + case "sagemaker": + url = isCancelAction + ? predictEndpoint.replace("CreateTransformJob", "StopTransformJob") + : predictEndpoint.replace("CreateTransformJob", "DescribeTransformJob"); + requestBody = "{ \"TransformJobName\" : \"${parameters.TransformJobName}\"}"; + break; + case "openai": + case "cohere": + url = isCancelAction ? predictEndpoint + "/${parameters.id}/cancel" : predictEndpoint + "/${parameters.id}"; + method = isCancelAction ? "POST" : "GET"; + break; + case "bedrock": + url = isCancelAction + ? predictEndpoint + "/${parameters.processedJobArn}/stop" + : predictEndpoint + "/${parameters.processedJobArn}"; + method = isCancelAction ? "POST" : "GET"; + break; + default: + String errorMessage = isCancelAction + ? "Please configure the action type to cancel the batch job in the connector" + : "Please configure the action type to get the batch job details in the connector"; + throw new UnsupportedOperationException(errorMessage); + } + + return ConnectorAction + .builder() + .actionType(actionType) + .method(method) + .url(url) + .requestBody(requestBody) + .headers(batchPredictAction.get().getHeaders()) + .build(); + } + + public static String getRemoteServerFromURL(String url) { + return SUPPORTED_REMOTE_SERVERS_FOR_DEFAULT_ACTION_TYPES.stream().filter(url::contains).findFirst().orElse(""); + } } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java index 31b0f5e420..335dc95245 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java @@ -5,8 +5,13 @@ package org.opensearch.ml.engine.algorithms.remote; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.when; +import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.BATCH_PREDICT_STATUS; +import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.CANCEL_BATCH_PREDICT; import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.PREDICT; import static org.opensearch.ml.common.utils.StringUtils.gson; @@ -134,7 +139,7 @@ private void processInput_RemoteInferenceInputDataSet(String input, String expec .actions(Arrays.asList(predictAction)) .build(); ConnectorUtils.processInput(PREDICT.name(), mlInput, connector, new HashMap<>(), scriptService); - Assert.assertEquals(expectedInput, ((RemoteInferenceInputDataSet) mlInput.getInputDataset()).getParameters().get("input")); + assertEquals(expectedInput, ((RemoteInferenceInputDataSet) mlInput.getInputDataset()).getParameters().get("input")); } @Test @@ -195,9 +200,9 @@ public void processOutput_NoPostprocessFunction_jsonResponse() throws IOExceptio "{\"object\":\"list\",\"data\":[{\"object\":\"embedding\",\"index\":0,\"embedding\":[-0.014555434,-0.0002135904,0.0035105038]}],\"model\":\"text-embedding-ada-002-v2\",\"usage\":{\"prompt_tokens\":5,\"total_tokens\":5}}"; ModelTensors tensors = ConnectorUtils .processOutput(PREDICT.name(), modelResponse, connector, scriptService, ImmutableMap.of(), null); - Assert.assertEquals(1, tensors.getMlModelTensors().size()); - Assert.assertEquals("response", tensors.getMlModelTensors().get(0).getName()); - Assert.assertEquals(4, tensors.getMlModelTensors().get(0).getDataAsMap().size()); + assertEquals(1, tensors.getMlModelTensors().size()); + assertEquals("response", tensors.getMlModelTensors().get(0).getName()); + assertEquals(4, tensors.getMlModelTensors().get(0).getDataAsMap().size()); } @Test @@ -228,13 +233,13 @@ public void processOutput_PostprocessFunction() throws IOException { "{\"object\":\"list\",\"data\":[{\"object\":\"embedding\",\"index\":0,\"embedding\":[-0.014555434,-0.0002135904,0.0035105038]}],\"model\":\"text-embedding-ada-002-v2\",\"usage\":{\"prompt_tokens\":5,\"total_tokens\":5}}"; ModelTensors tensors = ConnectorUtils .processOutput(PREDICT.name(), modelResponse, connector, scriptService, ImmutableMap.of(), null); - Assert.assertEquals(1, tensors.getMlModelTensors().size()); - Assert.assertEquals("sentence_embedding", tensors.getMlModelTensors().get(0).getName()); - Assert.assertNull(tensors.getMlModelTensors().get(0).getDataAsMap()); - Assert.assertEquals(3, tensors.getMlModelTensors().get(0).getData().length); - Assert.assertEquals(-0.014555434, tensors.getMlModelTensors().get(0).getData()[0]); - Assert.assertEquals(-0.0002135904, tensors.getMlModelTensors().get(0).getData()[1]); - Assert.assertEquals(0.0035105038, tensors.getMlModelTensors().get(0).getData()[2]); + assertEquals(1, tensors.getMlModelTensors().size()); + assertEquals("sentence_embedding", tensors.getMlModelTensors().get(0).getName()); + assertNull(tensors.getMlModelTensors().get(0).getDataAsMap()); + assertEquals(3, tensors.getMlModelTensors().get(0).getData().length); + assertEquals(-0.014555434, tensors.getMlModelTensors().get(0).getData()[0]); + assertEquals(-0.0002135904, tensors.getMlModelTensors().get(0).getData()[1]); + assertEquals(0.0035105038, tensors.getMlModelTensors().get(0).getData()[2]); } private void processInput_TextDocsInputDataSet_PreprocessFunction( @@ -268,7 +273,117 @@ private void processInput_TextDocsInputDataSet_PreprocessFunction( RemoteInferenceInputDataSet remoteInferenceInputDataSet = ConnectorUtils .processInput(PREDICT.name(), mlInput, connector, new HashMap<>(), scriptService); Assert.assertNotNull(remoteInferenceInputDataSet.getParameters()); - Assert.assertEquals(1, remoteInferenceInputDataSet.getParameters().size()); - Assert.assertEquals(expectedProcessedInput, remoteInferenceInputDataSet.getParameters().get(resultKey)); + assertEquals(1, remoteInferenceInputDataSet.getParameters().size()); + assertEquals(expectedProcessedInput, remoteInferenceInputDataSet.getParameters().get(resultKey)); + } + + @Test + public void testGetTask_createBatchStatusActionForSageMaker() { + Connector connector1 = HttpConnector + .builder() + .name("test") + .protocol("http") + .version("1") + .credential(Map.of("api_key", "credential_value")) + .parameters(Map.of("param1", "value1")) + .actions( + new ArrayList<>( + Arrays + .asList( + ConnectorAction + .builder() + .actionType(ConnectorAction.ActionType.BATCH_PREDICT) + .method("POST") + .url("https://api.sagemaker.us-east-1.amazonaws.com/CreateTransformJob") + .headers(Map.of("Authorization", "Bearer ${credential.api_key}")) + .requestBody("{ \"TransformJobName\" : \"${parameters.TransformJobName}\"}") + .build() + ) + ) + ) + .build(); + + ConnectorAction result = ConnectorUtils.createConnectorAction(connector1, BATCH_PREDICT_STATUS); + + assertEquals(ConnectorAction.ActionType.BATCH_PREDICT_STATUS, result.getActionType()); + assertEquals("POST", result.getMethod()); + assertEquals("https://api.sagemaker.us-east-1.amazonaws.com/DescribeTransformJob", result.getUrl()); + assertEquals("{ \"TransformJobName\" : \"${parameters.TransformJobName}\"}", result.getRequestBody()); + assertTrue(result.getHeaders().containsKey("Authorization")); + + } + + @Test + public void testGetTask_createBatchStatusActionForOpenAI() { + Connector connector1 = HttpConnector + .builder() + .name("test") + .protocol("http") + .version("1") + .credential(Map.of("api_key", "credential_value")) + .parameters(Map.of("param1", "value1")) + .actions( + new ArrayList<>( + Arrays + .asList( + ConnectorAction + .builder() + .actionType(ConnectorAction.ActionType.BATCH_PREDICT) + .method("POST") + .url("https://api.openai.com/v1/batches") + .headers(Map.of("Authorization", "Bearer ${credential.openAI_key}")) + .requestBody("{ \\\"input_file_id\\\": \\\"${parameters.input_file_id}\\\" }") + .build() + ) + ) + ) + .build(); + + ConnectorAction result = ConnectorUtils.createConnectorAction(connector1, BATCH_PREDICT_STATUS); + + assertEquals(ConnectorAction.ActionType.BATCH_PREDICT_STATUS, result.getActionType()); + assertEquals("GET", result.getMethod()); + assertEquals("https://api.openai.com/v1/batches/${parameters.id}", result.getUrl()); + assertNull(result.getRequestBody()); + assertTrue(result.getHeaders().containsKey("Authorization")); + } + + @Test + public void testGetTask_createCancelBatchActionForBedrock() { + Connector connector1 = HttpConnector + .builder() + .name("test") + .protocol("http") + .version("1") + .credential(Map.of("api_key", "credential_value")) + .parameters(Map.of("param1", "value1")) + .actions( + new ArrayList<>( + Arrays + .asList( + ConnectorAction + .builder() + .actionType(ConnectorAction.ActionType.BATCH_PREDICT) + .method("POST") + .url("https://bedrock.${parameters.region}.amazonaws.com/model-invocation-job") + .requestBody( + "{\\\"inputDataConfig\\\":{\\\"s3InputDataConfig\\\":{\\\"s3Uri\\\":\\\"${parameters.input_s3Uri}\\\"}},\\\"jobName\\\":\\\"${parameters.job_name}\\\",\\\"modelId\\\":\\\"${parameters.model}\\\",\\\"outputDataConfig\\\":{\\\"s3OutputDataConfig\\\":{\\\"s3Uri\\\":\\\"${parameters.output_s3Uri}\\\"}},\\\"roleArn\\\":\\\"${parameters.role_arn}\\\"}" + ) + .postProcessFunction("connector.post_process.bedrock.batch_job_arn") + .build() + ) + ) + ) + .build(); + + ConnectorAction result = ConnectorUtils.createConnectorAction(connector1, CANCEL_BATCH_PREDICT); + + assertEquals(ConnectorAction.ActionType.CANCEL_BATCH_PREDICT, result.getActionType()); + assertEquals("POST", result.getMethod()); + assertEquals( + "https://bedrock.${parameters.region}.amazonaws.com/model-invocation-job/${parameters.processedJobArn}/stop", + result.getUrl() + ); + assertNull(result.getRequestBody()); } } diff --git a/plugin/src/main/java/org/opensearch/ml/action/tasks/CancelBatchJobTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/tasks/CancelBatchJobTransportAction.java index 95e43ca929..b32c6243ab 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/tasks/CancelBatchJobTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/tasks/CancelBatchJobTransportAction.java @@ -38,6 +38,7 @@ import org.opensearch.ml.common.MLTask; import org.opensearch.ml.common.MLTaskType; import org.opensearch.ml.common.connector.Connector; +import org.opensearch.ml.common.connector.ConnectorAction; import org.opensearch.ml.common.connector.ConnectorAction.ActionType; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.exception.MLResourceNotFoundException; @@ -49,6 +50,7 @@ import org.opensearch.ml.common.transport.task.MLCancelBatchJobRequest; import org.opensearch.ml.common.transport.task.MLCancelBatchJobResponse; import org.opensearch.ml.engine.MLEngineClassLoader; +import org.opensearch.ml.engine.algorithms.remote.ConnectorUtils; import org.opensearch.ml.engine.algorithms.remote.RemoteConnectorExecutor; import org.opensearch.ml.engine.encryptor.EncryptorImpl; import org.opensearch.ml.helper.ConnectorAccessControlHelper; @@ -210,6 +212,11 @@ private void processRemoteBatchPrediction(MLTask mlTask, ActionListener actionListener) { if (connectorAccessControlHelper.validateConnectorAccess(client, connector)) { + Optional cancelBatchPredictAction = connector.findAction(CANCEL_BATCH_PREDICT.name()); + if (!cancelBatchPredictAction.isPresent() || cancelBatchPredictAction.get().getRequestBody() == null) { + ConnectorAction connectorAction = ConnectorUtils.createConnectorAction(connector, CANCEL_BATCH_PREDICT); + connector.addAction(connectorAction); + } connector.decrypt(CANCEL_BATCH_PREDICT.name(), (credential) -> encryptor.decrypt(credential)); RemoteConnectorExecutor connectorExecutor = MLEngineClassLoader .initInstance(connector.getProtocol(), connector, Connector.class); diff --git a/plugin/src/main/java/org/opensearch/ml/action/tasks/GetTaskTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/tasks/GetTaskTransportAction.java index e2e9109cf2..35e4e6d83d 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/tasks/GetTaskTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/tasks/GetTaskTransportAction.java @@ -55,6 +55,7 @@ import org.opensearch.ml.common.MLTask; import org.opensearch.ml.common.MLTaskType; import org.opensearch.ml.common.connector.Connector; +import org.opensearch.ml.common.connector.ConnectorAction; import org.opensearch.ml.common.connector.ConnectorAction.ActionType; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.exception.MLResourceNotFoundException; @@ -66,6 +67,7 @@ import org.opensearch.ml.common.transport.task.MLTaskGetRequest; import org.opensearch.ml.common.transport.task.MLTaskGetResponse; import org.opensearch.ml.engine.MLEngineClassLoader; +import org.opensearch.ml.engine.algorithms.remote.ConnectorUtils; import org.opensearch.ml.engine.algorithms.remote.RemoteConnectorExecutor; import org.opensearch.ml.engine.encryptor.EncryptorImpl; import org.opensearch.ml.helper.ConnectorAccessControlHelper; @@ -279,6 +281,11 @@ private void executeConnector( ActionListener actionListener ) { if (connectorAccessControlHelper.validateConnectorAccess(client, connector)) { + Optional batchPredictStatusAction = connector.findAction(BATCH_PREDICT_STATUS.name()); + if (!batchPredictStatusAction.isPresent() || batchPredictStatusAction.get().getRequestBody() == null) { + ConnectorAction connectorAction = ConnectorUtils.createConnectorAction(connector, BATCH_PREDICT_STATUS); + connector.addAction(connectorAction); + } connector.decrypt(BATCH_PREDICT_STATUS.name(), (credential) -> encryptor.decrypt(credential)); RemoteConnectorExecutor connectorExecutor = MLEngineClassLoader .initInstance(connector.getProtocol(), connector, Connector.class); diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java index 525ae12a88..a59d7bbe2b 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java @@ -371,11 +371,11 @@ private void runPredict( mlTaskManager.createMLTask(mlTask, ActionListener.wrap(response -> { String taskId = response.getId(); mlTask.setTaskId(taskId); - MLPredictionOutput outputBuilder = MLPredictionOutput - .builder() - .taskId(taskId) - .status(MLTaskState.CREATED.name()) - .build(); + MLPredictionOutput outputBuilder = new MLPredictionOutput( + taskId, + MLTaskState.CREATED.name(), + remoteJob + ); MLTaskResponse predictOutput = MLTaskResponse.builder().output(outputBuilder).build(); internalListener.onResponse(predictOutput); diff --git a/plugin/src/test/java/org/opensearch/ml/action/tasks/GetTaskTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/tasks/GetTaskTransportActionTests.java index 25c43eb9b6..1c036adb92 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/tasks/GetTaskTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/tasks/GetTaskTransportActionTests.java @@ -193,6 +193,14 @@ public void setup() throws IOException { .actions( Arrays .asList( + ConnectorAction + .builder() + .actionType(ConnectorAction.ActionType.BATCH_PREDICT) + .method("POST") + .url("https://api.sagemaker.us-east-1.amazonaws.com/CreateTransformJob") + .headers(Map.of("Authorization", "Bearer ${credential.api_key}")) + .requestBody("{ \"TransformJobName\" : \"${parameters.TransformJobName}\"}") + .build(), ConnectorAction .builder() .actionType(ConnectorAction.ActionType.BATCH_PREDICT_STATUS)