
fix field mapping, add more error handling and remove checking jobId … #2933

Merged
merged 2 commits on Sep 12, 2024
Changes from all commits
@@ -75,6 +75,55 @@ protected double calculateSuccessRate(List<Double> successRates) {
);
}

/**
* Filters fields in the map where the value contains the specified source index as a prefix.
* When there is only one source file, users can skip the source[] prefix
Collaborator
@ylwu-amzn Sep 12, 2024

If there is only one source file, is it ok for cx to input source[0].xxx with this PR?

Collaborator Author

Yes, it works for both cases. This function makes sure it keeps the fields prefixed with source[0] as well as any fields that are not prefixed with "source".
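
For illustration only (not part of the diff): a minimal sketch of the two mapping styles described above, mirroring the constructor and JsonPath values used in the unit test added in this PR. The index name and field names are placeholders, and the call assumes same-package or subclass access to the protected method.

// Sole-source field map: one entry keeps the explicit source[0] prefix, one omits it.
// Assumes java.util.Map/HashMap and MLBatchIngestionInput are imported.
Map<String, Object> fieldMap = new HashMap<>();
fieldMap.put("question", "source[0].$.body.input[0]");                 // explicit source[0] prefix
fieldMap.put("answer_embedding", "$.response.body.data[1].embedding"); // prefix omitted

// Constructor arguments follow the added unit test; no extra ingest fields here.
MLBatchIngestionInput input = new MLBatchIngestionInput("indexName", fieldMap, new String[0], new HashMap<>(), new HashMap<>());

// Both entries are kept, and the source[0] prefix is reduced to the bare JsonPath:
// {question=$.body.input[0], answer_embedding=$.response.body.data[1].embedding}
Map<String, Object> filtered = filterFieldMappingSoleSource(input);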

*
* @param mlBatchIngestionInput The MLBatchIngestionInput.
* @return A new map of <fieldName: JsonPath> for all fields to be ingested.
*/
protected Map<String, Object> filterFieldMappingSoleSource(MLBatchIngestionInput mlBatchIngestionInput) {
Map<String, Object> fieldMap = mlBatchIngestionInput.getFieldMapping();
String prefix = "source[0]";

Map<String, Object> filteredFieldMap = fieldMap.entrySet().stream().filter(entry -> {
Object value = entry.getValue();
if (value instanceof String) {
String jsonPath = ((String) value);
return jsonPath.contains(prefix) || !jsonPath.startsWith("source");
} else if (value instanceof List) {
return ((List<String>) value).stream().anyMatch(val -> (val.contains(prefix) || !val.startsWith("source")));
}
return false;
}).collect(Collectors.toMap(Map.Entry::getKey, entry -> {
Object value = entry.getValue();
if (value instanceof String) {
return getJsonPath((String) value);
} else if (value instanceof List) {
return ((List<String>) value)
.stream()
.filter(val -> (val.contains(prefix) || !val.startsWith("source")))
.map(StringUtils::getJsonPath)
.collect(Collectors.toList());
}
return null;
}));

String[] ingestFields = mlBatchIngestionInput.getIngestFields();
if (ingestFields != null) {
Arrays
.stream(ingestFields)
.filter(Objects::nonNull)
.filter(val -> (val.contains(prefix) || !val.startsWith("source")))
.map(StringUtils::getJsonPath)
.forEach(jsonPath -> {
filteredFieldMap.put(obtainFieldNameFromJsonPath(jsonPath), jsonPath);
});
}

return filteredFieldMap;
}

/**
* Filters fields in the map where the value contains the specified source index as a prefix.
*
@@ -159,7 +208,7 @@ protected void batchIngest(
BulkRequest bulkRequest = new BulkRequest();
sourceLines.stream().forEach(jsonStr -> {
Map<String, Object> filteredMapping = isSoleSource
? mlBatchIngestionInput.getFieldMapping()
? filterFieldMappingSoleSource(mlBatchIngestionInput)
: filterFieldMapping(mlBatchIngestionInput, sourceIndex);
Map<String, Object> jsonMap = processFieldMapping(jsonStr, filteredMapping);
if (jsonMap.isEmpty()) {
@@ -174,7 +223,7 @@
if (!jsonMap.containsKey("_id")) {
throw new IllegalArgumentException("The id field must be provided to match documents for multiple sources");
}
String id = (String) jsonMap.remove("_id");
String id = String.valueOf(jsonMap.remove("_id"));
UpdateRequest updateRequest = new UpdateRequest(mlBatchIngestionInput.getIndexName(), id).doc(jsonMap).upsert(jsonMap);
bulkRequest.add(updateRequest);
}
@@ -204,6 +204,35 @@ public void testFilterFieldMapping_MatchingPrefix() {
assertEquals(Arrays.asList("$.custom_id"), result.get("_id"));
}

@Test
public void testFilterFieldMappingSoleSource_MatchingPrefix() {
// Arrange
Map<String, Object> fieldMap = new HashMap<>();
fieldMap.put("question", "source[0].$.body.input[0]");
fieldMap.put("question_embedding", "source[0].$.response.body.data[0].embedding");
fieldMap.put("answer", "source[0].$.body.input[1]");
fieldMap.put("answer_embedding", "$.response.body.data[1].embedding");
fieldMap.put("_id", Arrays.asList("$.custom_id", "source[1].$.custom_id"));

MLBatchIngestionInput mlBatchIngestionInput = new MLBatchIngestionInput(
"indexName",
fieldMap,
ingestFields,
new HashMap<>(),
new HashMap<>()
);

// Act
Map<String, Object> result = s3DataIngestion.filterFieldMappingSoleSource(mlBatchIngestionInput);

// Assert
assertEquals(6, result.size());

assertEquals("$.body.input[0]", result.get("question"));
assertEquals("$.response.body.data[0].embedding", result.get("question_embedding"));
assertEquals(Arrays.asList("$.custom_id"), result.get("_id"));
}

@Test
public void testProcessFieldMapping_FromSM() {
String jsonStr =
@@ -9,7 +9,7 @@
import static org.opensearch.ml.common.MLTask.STATE_FIELD;
import static org.opensearch.ml.common.MLTaskState.COMPLETED;
import static org.opensearch.ml.common.MLTaskState.FAILED;
import static org.opensearch.ml.plugin.MachineLearningPlugin.TRAIN_THREAD_POOL;
import static org.opensearch.ml.plugin.MachineLearningPlugin.INGEST_THREAD_POOL;
import static org.opensearch.ml.task.MLTaskManager.TASK_SEMAPHORE_TIMEOUT;

import java.time.Instant;
@@ -41,6 +41,8 @@
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportService;

import com.jayway.jsonpath.PathNotFoundException;

import lombok.extern.log4j.Log4j2;

@Log4j2
@@ -92,9 +94,11 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLBatc
listener.onResponse(new MLBatchIngestionResponse(taskId, MLTaskType.BATCH_INGEST, MLTaskState.CREATED.name()));
String ingestType = (String) mlBatchIngestionInput.getDataSources().get(TYPE);
Ingestable ingestable = MLEngineClassLoader.initInstance(ingestType.toLowerCase(), client, Client.class);
threadPool.executor(TRAIN_THREAD_POOL).execute(() -> {
double successRate = ingestable.ingest(mlBatchIngestionInput);
handleSuccessRate(successRate, taskId);
threadPool.executor(INGEST_THREAD_POOL).execute(() -> {
executeWithErrorHandling(() -> {
double successRate = ingestable.ingest(mlBatchIngestionInput);
handleSuccessRate(successRate, taskId);
}, taskId);
});
} catch (Exception ex) {
log.error("Failed in batch ingestion", ex);
@@ -125,6 +129,30 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLBatc
}
}

protected void executeWithErrorHandling(Runnable task, String taskId) {
try {
task.run();
} catch (PathNotFoundException jsonPathNotFoundException) {
log.error("Error in jsonParse fields", jsonPathNotFoundException);
mlTaskManager
.updateMLTask(
taskId,
Map.of(STATE_FIELD, FAILED, ERROR_FIELD, jsonPathNotFoundException.getMessage()),
TASK_SEMAPHORE_TIMEOUT,
true
);
} catch (Exception e) {
log.error("Error in ingest, failed to produce a successRate", e);
mlTaskManager
.updateMLTask(
taskId,
Map.of(STATE_FIELD, FAILED, ERROR_FIELD, MLExceptionUtils.getRootCauseMessage(e)),
TASK_SEMAPHORE_TIMEOUT,
true
);
}
}

protected void handleSuccessRate(double successRate, String taskId) {
if (successRate == 100) {
mlTaskManager.updateMLTask(taskId, Map.of(STATE_FIELD, COMPLETED), 5000, true);
@@ -333,6 +333,7 @@ public class MachineLearningPlugin extends Plugin
public static final String TRAIN_THREAD_POOL = "opensearch_ml_train";
public static final String PREDICT_THREAD_POOL = "opensearch_ml_predict";
public static final String REMOTE_PREDICT_THREAD_POOL = "opensearch_ml_predict_remote";
public static final String INGEST_THREAD_POOL = "opensearch_ml_ingest";
public static final String REGISTER_THREAD_POOL = "opensearch_ml_register";
public static final String DEPLOY_THREAD_POOL = "opensearch_ml_deploy";
public static final String ML_BASE_URI = "/_plugins/_ml";
@@ -885,6 +886,14 @@ public List<ExecutorBuilder<?>> getExecutorBuilders(Settings settings) {
ML_THREAD_POOL_PREFIX + REMOTE_PREDICT_THREAD_POOL,
false
);
FixedExecutorBuilder batchIngestThreadPool = new FixedExecutorBuilder(
settings,
INGEST_THREAD_POOL,
OpenSearchExecutors.allocatedProcessors(settings) * 4,
30,
ML_THREAD_POOL_PREFIX + INGEST_THREAD_POOL,
false
);

return ImmutableList
.of(
@@ -894,7 +903,8 @@ public List<ExecutorBuilder<?>> getExecutorBuilders(Settings settings) {
executeThreadPool,
trainThreadPool,
predictThreadPool,
remotePredictThreadPool
remotePredictThreadPool,
batchIngestThreadPool
);
}

@@ -146,7 +146,7 @@ private MLCommonsSettings() {}
"^https://api\\.openai\\.com/.*$",
"^https://api\\.cohere\\.ai/.*$",
"^https://bedrock-runtime\\..*[a-z0-9-]\\.amazonaws\\.com/.*$",
"^https://bedrock-agent-runtime\\..*[a-z0-9-]\\.amazonaws\\.com/.*$"
"^https://bedrock-agent-runtime\\..*[a-z0-9-]\\.amazonaws\\.com/.*$",
"^https://bedrock\\..*[a-z0-9-]\\.amazonaws\\.com/.*$"
),
Function.identity(),
Setting.Property.NodeScope,
@@ -358,13 +358,13 @@ private void runPredict(
&& tensorOutput.getMlModelOutputs() != null
&& !tensorOutput.getMlModelOutputs().isEmpty()) {
ModelTensors modelOutput = tensorOutput.getMlModelOutputs().get(0);
Integer statusCode = modelOutput.getStatusCode();
if (modelOutput.getMlModelTensors() != null && !modelOutput.getMlModelTensors().isEmpty()) {
Map<String, Object> dataAsMap = (Map<String, Object>) modelOutput
.getMlModelTensors()
.get(0)
.getDataAsMap();
if (dataAsMap != null
&& (dataAsMap.containsKey("TransformJobArn") || dataAsMap.containsKey("id"))) {
if (dataAsMap != null && statusCode != null && statusCode >= 200 && statusCode < 300) {
remoteJob.putAll(dataAsMap);
mlTask.setRemoteJob(remoteJob);
mlTask.setTaskId(null);
@@ -6,22 +6,29 @@
package org.opensearch.ml.action.batch;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.isA;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.opensearch.ml.common.MLTask.ERROR_FIELD;
import static org.opensearch.ml.common.MLTask.STATE_FIELD;
import static org.opensearch.ml.common.MLTaskState.COMPLETED;
import static org.opensearch.ml.common.MLTaskState.FAILED;
import static org.opensearch.ml.engine.ingest.S3DataIngestion.SOURCE;
import static org.opensearch.ml.plugin.MachineLearningPlugin.INGEST_THREAD_POOL;
import static org.opensearch.ml.task.MLTaskManager.TASK_SEMAPHORE_TIMEOUT;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ExecutorService;

import org.junit.Before;
import org.mockito.ArgumentCaptor;
@@ -45,6 +52,8 @@
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportService;

import com.jayway.jsonpath.PathNotFoundException;

public class TransportBatchIngestionActionTests extends OpenSearchTestCase {
@Mock
private Client client;
@@ -62,6 +71,8 @@ public class TransportBatchIngestionActionTests extends OpenSearchTestCase {
ActionListener<MLBatchIngestionResponse> actionListener;
@Mock
ThreadPool threadPool;
@Mock
ExecutorService executorService;

private TransportBatchIngestionAction batchAction;
private MLBatchIngestionInput batchInput;
@@ -105,9 +116,42 @@ public void test_doExecute_success() {
listener.onResponse(indexResponse);
return null;
}).when(mlTaskManager).createMLTask(isA(MLTask.class), isA(ActionListener.class));
doReturn(executorService).when(threadPool).executor(INGEST_THREAD_POOL);
doAnswer(invocation -> {
Runnable runnable = invocation.getArgument(0);
runnable.run();
return null;
}).when(executorService).execute(any(Runnable.class));

batchAction.doExecute(task, mlBatchIngestionRequest, actionListener);

verify(actionListener).onResponse(any(MLBatchIngestionResponse.class));
verify(threadPool).executor(INGEST_THREAD_POOL);
}

public void test_doExecute_ExecuteWithNoErrorHandling() {
batchAction.executeWithErrorHandling(() -> {}, "taskId");

verify(mlTaskManager, never()).updateMLTask(anyString(), isA(Map.class), anyLong(), anyBoolean());
}

public void test_doExecute_ExecuteWithPathNotFoundException() {
batchAction.executeWithErrorHandling(() -> { throw new PathNotFoundException("jsonPath not found!"); }, "taskId");

verify(mlTaskManager)
.updateMLTask("taskId", Map.of(STATE_FIELD, FAILED, ERROR_FIELD, "jsonPath not found!"), TASK_SEMAPHORE_TIMEOUT, true);
}

public void test_doExecute_RuntimeException() {
batchAction.executeWithErrorHandling(() -> { throw new RuntimeException("runtime exception in the ingestion!"); }, "taskId");

verify(mlTaskManager)
.updateMLTask(
"taskId",
Map.of(STATE_FIELD, FAILED, ERROR_FIELD, "runtime exception in the ingestion!"),
TASK_SEMAPHORE_TIMEOUT,
true
);
}

public void test_doExecute_handleSuccessRate100() {
@@ -447,10 +447,9 @@ public void testValidateBatchPredictionSuccess() throws IOException {
"output",
"{\"properties\":{\"inference_results\":{\"description\":\"This is a test description field\"," + "\"type\":\"array\"}}}"
);
ModelTensorOutput modelTensorOutput = ModelTensorOutput
.builder()
.mlModelOutputs(List.of(ModelTensors.builder().mlModelTensors(List.of(modelTensor)).build()))
.build();
ModelTensors modelTensors = ModelTensors.builder().statusCode(200).mlModelTensors(List.of(modelTensor)).statusCode(200).build();
modelTensors.setStatusCode(200);
ModelTensorOutput modelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(List.of(modelTensors)).build();
doAnswer(invocation -> {
ActionListener<MLTaskResponse> actionListener = invocation.getArgument(1);
actionListener.onResponse(MLTaskResponse.builder().output(modelTensorOutput).build());
Expand Down