diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/ingest/AbstractIngestion.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/ingest/AbstractIngestion.java index e0b075f567..ce6fa19145 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/ingest/AbstractIngestion.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/ingest/AbstractIngestion.java @@ -75,6 +75,55 @@ protected double calculateSuccessRate(List successRates) { ); } + /** + * Filters fields in the map where the value contains the specified source index as a prefix. + * When there is only one source file, users can skip the source[] prefix + * + * @param mlBatchIngestionInput The MLBatchIngestionInput. + * @return A new map of for all fields to be ingested. + */ + protected Map filterFieldMappingSoleSource(MLBatchIngestionInput mlBatchIngestionInput) { + Map fieldMap = mlBatchIngestionInput.getFieldMapping(); + String prefix = "source[0]"; + + Map filteredFieldMap = fieldMap.entrySet().stream().filter(entry -> { + Object value = entry.getValue(); + if (value instanceof String) { + String jsonPath = ((String) value); + return jsonPath.contains(prefix) || !jsonPath.startsWith("source"); + } else if (value instanceof List) { + return ((List) value).stream().anyMatch(val -> (val.contains(prefix) || !val.startsWith("source"))); + } + return false; + }).collect(Collectors.toMap(Map.Entry::getKey, entry -> { + Object value = entry.getValue(); + if (value instanceof String) { + return getJsonPath((String) value); + } else if (value instanceof List) { + return ((List) value) + .stream() + .filter(val -> (val.contains(prefix) || !val.startsWith("source"))) + .map(StringUtils::getJsonPath) + .collect(Collectors.toList()); + } + return null; + })); + + String[] ingestFields = mlBatchIngestionInput.getIngestFields(); + if (ingestFields != null) { + Arrays + .stream(ingestFields) + .filter(Objects::nonNull) + .filter(val -> (val.contains(prefix) || !val.startsWith("source"))) + .map(StringUtils::getJsonPath) + .forEach(jsonPath -> { + filteredFieldMap.put(obtainFieldNameFromJsonPath(jsonPath), jsonPath); + }); + } + + return filteredFieldMap; + } + /** * Filters fields in the map where the value contains the specified source index as a prefix. * @@ -159,7 +208,7 @@ protected void batchIngest( BulkRequest bulkRequest = new BulkRequest(); sourceLines.stream().forEach(jsonStr -> { Map filteredMapping = isSoleSource - ? mlBatchIngestionInput.getFieldMapping() + ? filterFieldMappingSoleSource(mlBatchIngestionInput) : filterFieldMapping(mlBatchIngestionInput, sourceIndex); Map jsonMap = processFieldMapping(jsonStr, filteredMapping); if (jsonMap.isEmpty()) { @@ -174,7 +223,7 @@ protected void batchIngest( if (!jsonMap.containsKey("_id")) { throw new IllegalArgumentException("The id filed must be provided to match documents for multiple sources"); } - String id = (String) jsonMap.remove("_id"); + String id = String.valueOf(jsonMap.remove("_id")); UpdateRequest updateRequest = new UpdateRequest(mlBatchIngestionInput.getIndexName(), id).doc(jsonMap).upsert(jsonMap); bulkRequest.add(updateRequest); } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/ingest/AbstractIngestionTests.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/ingest/AbstractIngestionTests.java index a4c155ba77..1f1653b31c 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/ingest/AbstractIngestionTests.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/ingest/AbstractIngestionTests.java @@ -204,6 +204,35 @@ public void testFilterFieldMapping_MatchingPrefix() { assertEquals(Arrays.asList("$.custom_id"), result.get("_id")); } + @Test + public void testFilterFieldMappingSoleSource_MatchingPrefix() { + // Arrange + Map fieldMap = new HashMap<>(); + fieldMap.put("question", "source[0].$.body.input[0]"); + fieldMap.put("question_embedding", "source[0].$.response.body.data[0].embedding"); + fieldMap.put("answer", "source[0].$.body.input[1]"); + fieldMap.put("answer_embedding", "$.response.body.data[1].embedding"); + fieldMap.put("_id", Arrays.asList("$.custom_id", "source[1].$.custom_id")); + + MLBatchIngestionInput mlBatchIngestionInput = new MLBatchIngestionInput( + "indexName", + fieldMap, + ingestFields, + new HashMap<>(), + new HashMap<>() + ); + + // Act + Map result = s3DataIngestion.filterFieldMappingSoleSource(mlBatchIngestionInput); + + // Assert + assertEquals(6, result.size()); + + assertEquals("$.body.input[0]", result.get("question")); + assertEquals("$.response.body.data[0].embedding", result.get("question_embedding")); + assertEquals(Arrays.asList("$.custom_id"), result.get("_id")); + } + @Test public void testProcessFieldMapping_FromSM() { String jsonStr = diff --git a/plugin/src/main/java/org/opensearch/ml/action/batch/TransportBatchIngestionAction.java b/plugin/src/main/java/org/opensearch/ml/action/batch/TransportBatchIngestionAction.java index cf03d0f11a..6fd03b7b52 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/batch/TransportBatchIngestionAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/batch/TransportBatchIngestionAction.java @@ -9,7 +9,7 @@ import static org.opensearch.ml.common.MLTask.STATE_FIELD; import static org.opensearch.ml.common.MLTaskState.COMPLETED; import static org.opensearch.ml.common.MLTaskState.FAILED; -import static org.opensearch.ml.plugin.MachineLearningPlugin.TRAIN_THREAD_POOL; +import static org.opensearch.ml.plugin.MachineLearningPlugin.INGEST_THREAD_POOL; import static org.opensearch.ml.task.MLTaskManager.TASK_SEMAPHORE_TIMEOUT; import java.time.Instant; @@ -41,6 +41,8 @@ import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; +import com.jayway.jsonpath.PathNotFoundException; + import lombok.extern.log4j.Log4j2; @Log4j2 @@ -92,9 +94,11 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { - double successRate = ingestable.ingest(mlBatchIngestionInput); - handleSuccessRate(successRate, taskId); + threadPool.executor(INGEST_THREAD_POOL).execute(() -> { + executeWithErrorHandling(() -> { + double successRate = ingestable.ingest(mlBatchIngestionInput); + handleSuccessRate(successRate, taskId); + }, taskId); }); } catch (Exception ex) { log.error("Failed in batch ingestion", ex); @@ -125,6 +129,30 @@ protected void doExecute(Task task, ActionRequest request, ActionListener> getExecutorBuilders(Settings settings) { ML_THREAD_POOL_PREFIX + REMOTE_PREDICT_THREAD_POOL, false ); + FixedExecutorBuilder batchIngestThreadPool = new FixedExecutorBuilder( + settings, + INGEST_THREAD_POOL, + OpenSearchExecutors.allocatedProcessors(settings) * 4, + 30, + ML_THREAD_POOL_PREFIX + INGEST_THREAD_POOL, + false + ); return ImmutableList .of( @@ -894,7 +903,8 @@ public List> getExecutorBuilders(Settings settings) { executeThreadPool, trainThreadPool, predictThreadPool, - remotePredictThreadPool + remotePredictThreadPool, + batchIngestThreadPool ); } diff --git a/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java b/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java index 6daffd30fd..339116226d 100644 --- a/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java +++ b/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java @@ -146,7 +146,8 @@ private MLCommonsSettings() {} "^https://api\\.openai\\.com/.*$", "^https://api\\.cohere\\.ai/.*$", "^https://bedrock-runtime\\..*[a-z0-9-]\\.amazonaws\\.com/.*$", - "^https://bedrock-agent-runtime\\..*[a-z0-9-]\\.amazonaws\\.com/.*$" + "^https://bedrock-agent-runtime\\..*[a-z0-9-]\\.amazonaws\\.com/.*$", + "^https://bedrock\\..*[a-z0-9-]\\.amazonaws\\.com/.*$" ), Function.identity(), Setting.Property.NodeScope, diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java index 3b2e70d4b8..525ae12a88 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java @@ -358,13 +358,13 @@ private void runPredict( && tensorOutput.getMlModelOutputs() != null && !tensorOutput.getMlModelOutputs().isEmpty()) { ModelTensors modelOutput = tensorOutput.getMlModelOutputs().get(0); + Integer statusCode = modelOutput.getStatusCode(); if (modelOutput.getMlModelTensors() != null && !modelOutput.getMlModelTensors().isEmpty()) { Map dataAsMap = (Map) modelOutput .getMlModelTensors() .get(0) .getDataAsMap(); - if (dataAsMap != null - && (dataAsMap.containsKey("TransformJobArn") || dataAsMap.containsKey("id"))) { + if (dataAsMap != null && statusCode != null && statusCode >= 200 && statusCode < 300) { remoteJob.putAll(dataAsMap); mlTask.setRemoteJob(remoteJob); mlTask.setTaskId(null); diff --git a/plugin/src/test/java/org/opensearch/ml/action/batch/TransportBatchIngestionActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/batch/TransportBatchIngestionActionTests.java index 2916359110..092edfe951 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/batch/TransportBatchIngestionActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/batch/TransportBatchIngestionActionTests.java @@ -6,9 +6,14 @@ package org.opensearch.ml.action.batch; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.isA; import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.opensearch.ml.common.MLTask.ERROR_FIELD; @@ -16,12 +21,14 @@ import static org.opensearch.ml.common.MLTaskState.COMPLETED; import static org.opensearch.ml.common.MLTaskState.FAILED; import static org.opensearch.ml.engine.ingest.S3DataIngestion.SOURCE; +import static org.opensearch.ml.plugin.MachineLearningPlugin.INGEST_THREAD_POOL; import static org.opensearch.ml.task.MLTaskManager.TASK_SEMAPHORE_TIMEOUT; import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; import java.util.Map; +import java.util.concurrent.ExecutorService; import org.junit.Before; import org.mockito.ArgumentCaptor; @@ -45,6 +52,8 @@ import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; +import com.jayway.jsonpath.PathNotFoundException; + public class TransportBatchIngestionActionTests extends OpenSearchTestCase { @Mock private Client client; @@ -62,6 +71,8 @@ public class TransportBatchIngestionActionTests extends OpenSearchTestCase { ActionListener actionListener; @Mock ThreadPool threadPool; + @Mock + ExecutorService executorService; private TransportBatchIngestionAction batchAction; private MLBatchIngestionInput batchInput; @@ -105,9 +116,42 @@ public void test_doExecute_success() { listener.onResponse(indexResponse); return null; }).when(mlTaskManager).createMLTask(isA(MLTask.class), isA(ActionListener.class)); + doReturn(executorService).when(threadPool).executor(INGEST_THREAD_POOL); + doAnswer(invocation -> { + Runnable runnable = invocation.getArgument(0); + runnable.run(); + return null; + }).when(executorService).execute(any(Runnable.class)); + batchAction.doExecute(task, mlBatchIngestionRequest, actionListener); verify(actionListener).onResponse(any(MLBatchIngestionResponse.class)); + verify(threadPool).executor(INGEST_THREAD_POOL); + } + + public void test_doExecute_ExecuteWithNoErrorHandling() { + batchAction.executeWithErrorHandling(() -> {}, "taskId"); + + verify(mlTaskManager, never()).updateMLTask(anyString(), isA(Map.class), anyLong(), anyBoolean()); + } + + public void test_doExecute_ExecuteWithPathNotFoundException() { + batchAction.executeWithErrorHandling(() -> { throw new PathNotFoundException("jsonPath not found!"); }, "taskId"); + + verify(mlTaskManager) + .updateMLTask("taskId", Map.of(STATE_FIELD, FAILED, ERROR_FIELD, "jsonPath not found!"), TASK_SEMAPHORE_TIMEOUT, true); + } + + public void test_doExecute_RuntimeException() { + batchAction.executeWithErrorHandling(() -> { throw new RuntimeException("runtime exception in the ingestion!"); }, "taskId"); + + verify(mlTaskManager) + .updateMLTask( + "taskId", + Map.of(STATE_FIELD, FAILED, ERROR_FIELD, "runtime exception in the ingestion!"), + TASK_SEMAPHORE_TIMEOUT, + true + ); } public void test_doExecute_handleSuccessRate100() { diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java index 064008a9c4..223f2ce5a5 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java @@ -447,10 +447,9 @@ public void testValidateBatchPredictionSuccess() throws IOException { "output", "{\"properties\":{\"inference_results\":{\"description\":\"This is a test description field\"," + "\"type\":\"array\"}}}" ); - ModelTensorOutput modelTensorOutput = ModelTensorOutput - .builder() - .mlModelOutputs(List.of(ModelTensors.builder().mlModelTensors(List.of(modelTensor)).build())) - .build(); + ModelTensors modelTensors = ModelTensors.builder().statusCode(200).mlModelTensors(List.of(modelTensor)).statusCode(200).build(); + modelTensors.setStatusCode(200); + ModelTensorOutput modelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(List.of(modelTensors)).build(); doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); actionListener.onResponse(MLTaskResponse.builder().output(modelTensorOutput).build());