Skip to content

Commit

Permalink
fix field mapping, add more error handling and remove checking jobId …
Browse files Browse the repository at this point in the history
…filed in batch job response

Signed-off-by: Xun Zhang <[email protected]>
  • Loading branch information
Zhangxunmt committed Sep 12, 2024
1 parent 93d0429 commit 57051bd
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,55 @@ protected double calculateSuccessRate(List<Double> successRates) {
);
}

/**
* Filters fields in the map where the value contains the specified source index as a prefix.
* When there is only one source file, users can skip the source[] prefix
*
* @param mlBatchIngestionInput The MLBatchIngestionInput.
* @return A new map of <fieldName: JsonPath> for all fields to be ingested.
*/
protected Map<String, Object> filterFieldMappingSoleSource(MLBatchIngestionInput mlBatchIngestionInput) {
Map<String, Object> fieldMap = mlBatchIngestionInput.getFieldMapping();
String prefix = "source[0]";

Map<String, Object> filteredFieldMap = fieldMap.entrySet().stream().filter(entry -> {
Object value = entry.getValue();
if (value instanceof String) {
String jsonPath = ((String) value);
return jsonPath.contains(prefix) || !jsonPath.startsWith("source");
} else if (value instanceof List) {
return ((List<String>) value).stream().anyMatch(val -> (val.contains(prefix) || !val.startsWith("source")));
}
return false;
}).collect(Collectors.toMap(Map.Entry::getKey, entry -> {
Object value = entry.getValue();
if (value instanceof String) {
return getJsonPath((String) value);
} else if (value instanceof List) {
return ((List<String>) value)
.stream()
.filter(val -> (val.contains(prefix) || !val.startsWith("source")))
.map(StringUtils::getJsonPath)
.collect(Collectors.toList());
}
return null;
}));

String[] ingestFields = mlBatchIngestionInput.getIngestFields();
if (ingestFields != null) {
Arrays
.stream(ingestFields)
.filter(Objects::nonNull)
.filter(val -> (val.contains(prefix) || !val.startsWith("source")))
.map(StringUtils::getJsonPath)
.forEach(jsonPath -> {
filteredFieldMap.put(obtainFieldNameFromJsonPath(jsonPath), jsonPath);
});
}

return filteredFieldMap;
}

/**
* Filters fields in the map where the value contains the specified source index as a prefix.
*
Expand Down Expand Up @@ -159,7 +208,7 @@ protected void batchIngest(
BulkRequest bulkRequest = new BulkRequest();
sourceLines.stream().forEach(jsonStr -> {
Map<String, Object> filteredMapping = isSoleSource
? mlBatchIngestionInput.getFieldMapping()
? filterFieldMappingSoleSource(mlBatchIngestionInput)
: filterFieldMapping(mlBatchIngestionInput, sourceIndex);
Map<String, Object> jsonMap = processFieldMapping(jsonStr, filteredMapping);
if (jsonMap.isEmpty()) {
Expand All @@ -174,7 +223,7 @@ protected void batchIngest(
if (!jsonMap.containsKey("_id")) {
throw new IllegalArgumentException("The id filed must be provided to match documents for multiple sources");
}
String id = (String) jsonMap.remove("_id");
String id = String.valueOf(jsonMap.remove("_id"));
UpdateRequest updateRequest = new UpdateRequest(mlBatchIngestionInput.getIndexName(), id).doc(jsonMap).upsert(jsonMap);
bulkRequest.add(updateRequest);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import static org.opensearch.ml.common.MLTask.STATE_FIELD;
import static org.opensearch.ml.common.MLTaskState.COMPLETED;
import static org.opensearch.ml.common.MLTaskState.FAILED;
import static org.opensearch.ml.plugin.MachineLearningPlugin.TRAIN_THREAD_POOL;
import static org.opensearch.ml.plugin.MachineLearningPlugin.INGEST_THREAD_POOL;
import static org.opensearch.ml.task.MLTaskManager.TASK_SEMAPHORE_TIMEOUT;

import java.time.Instant;
Expand Down Expand Up @@ -41,6 +41,8 @@
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportService;

import com.jayway.jsonpath.PathNotFoundException;

import lombok.extern.log4j.Log4j2;

@Log4j2
Expand Down Expand Up @@ -92,9 +94,11 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLBatc
listener.onResponse(new MLBatchIngestionResponse(taskId, MLTaskType.BATCH_INGEST, MLTaskState.CREATED.name()));
String ingestType = (String) mlBatchIngestionInput.getDataSources().get(TYPE);
Ingestable ingestable = MLEngineClassLoader.initInstance(ingestType.toLowerCase(), client, Client.class);
threadPool.executor(TRAIN_THREAD_POOL).execute(() -> {
double successRate = ingestable.ingest(mlBatchIngestionInput);
handleSuccessRate(successRate, taskId);
threadPool.executor(INGEST_THREAD_POOL).execute(() -> {
executeWithErrorHandling(() -> {
double successRate = ingestable.ingest(mlBatchIngestionInput);
handleSuccessRate(successRate, taskId);
}, taskId);
});
} catch (Exception ex) {
log.error("Failed in batch ingestion", ex);
Expand Down Expand Up @@ -125,6 +129,30 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLBatc
}
}

protected void executeWithErrorHandling(Runnable task, String taskId) {
try {
task.run();
} catch (PathNotFoundException jsonPathNotFoundException) {
log.error("Error in jsonParse fields", jsonPathNotFoundException);
mlTaskManager
.updateMLTask(
taskId,
Map.of(STATE_FIELD, FAILED, ERROR_FIELD, jsonPathNotFoundException.getMessage()),
TASK_SEMAPHORE_TIMEOUT,
true
);
} catch (Exception e) {
log.error("Error in ingest, failed to produce a successRate", e);
mlTaskManager
.updateMLTask(
taskId,
Map.of(STATE_FIELD, FAILED, ERROR_FIELD, MLExceptionUtils.getRootCauseMessage(e)),
TASK_SEMAPHORE_TIMEOUT,
true
);
}
}

protected void handleSuccessRate(double successRate, String taskId) {
if (successRate == 100) {
mlTaskManager.updateMLTask(taskId, Map.of(STATE_FIELD, COMPLETED), 5000, true);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,7 @@ public class MachineLearningPlugin extends Plugin
public static final String TRAIN_THREAD_POOL = "opensearch_ml_train";
public static final String PREDICT_THREAD_POOL = "opensearch_ml_predict";
public static final String REMOTE_PREDICT_THREAD_POOL = "opensearch_ml_predict_remote";
public static final String INGEST_THREAD_POOL = "opensearch_ml_ingest";
public static final String REGISTER_THREAD_POOL = "opensearch_ml_register";
public static final String DEPLOY_THREAD_POOL = "opensearch_ml_deploy";
public static final String ML_BASE_URI = "/_plugins/_ml";
Expand Down Expand Up @@ -885,6 +886,14 @@ public List<ExecutorBuilder<?>> getExecutorBuilders(Settings settings) {
ML_THREAD_POOL_PREFIX + REMOTE_PREDICT_THREAD_POOL,
false
);
FixedExecutorBuilder batchIngestThreadPool = new FixedExecutorBuilder(
settings,
INGEST_THREAD_POOL,
OpenSearchExecutors.allocatedProcessors(settings) * 4,
30,
ML_THREAD_POOL_PREFIX + INGEST_THREAD_POOL,
false
);

return ImmutableList
.of(
Expand All @@ -894,7 +903,8 @@ public List<ExecutorBuilder<?>> getExecutorBuilders(Settings settings) {
executeThreadPool,
trainThreadPool,
predictThreadPool,
remotePredictThreadPool
remotePredictThreadPool,
batchIngestThreadPool
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -358,13 +358,13 @@ private void runPredict(
&& tensorOutput.getMlModelOutputs() != null
&& !tensorOutput.getMlModelOutputs().isEmpty()) {
ModelTensors modelOutput = tensorOutput.getMlModelOutputs().get(0);
Integer statusCode = modelOutput.getStatusCode();
if (modelOutput.getMlModelTensors() != null && !modelOutput.getMlModelTensors().isEmpty()) {
Map<String, Object> dataAsMap = (Map<String, Object>) modelOutput
.getMlModelTensors()
.get(0)
.getDataAsMap();
if (dataAsMap != null
&& (dataAsMap.containsKey("TransformJobArn") || dataAsMap.containsKey("id"))) {
if (dataAsMap != null && statusCode != null && statusCode >= 200 && statusCode < 300) {
remoteJob.putAll(dataAsMap);
mlTask.setRemoteJob(remoteJob);
mlTask.setTaskId(null);
Expand Down

0 comments on commit 57051bd

Please sign in to comment.