Skip to content

Commit

Permalink
update error code to 429 for rate limiting and update logs
Browse files Browse the repository at this point in the history
Signed-off-by: Xun Zhang <[email protected]>
  • Loading branch information
Zhangxunmt committed Oct 16, 2024
1 parent a22e926 commit 25dd297
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.MLTaskState;
import org.opensearch.ml.common.MLTaskType;
import org.opensearch.ml.common.exception.MLLimitExceededException;
import org.opensearch.ml.common.transport.batch.MLBatchIngestionAction;
import org.opensearch.ml.common.transport.batch.MLBatchIngestionInput;
import org.opensearch.ml.common.transport.batch.MLBatchIngestionRequest;
Expand Down Expand Up @@ -146,9 +145,10 @@ protected void createMLTaskandExecute(MLBatchIngestionInput mlBatchIngestionInpu

mlModelManager.checkMaxBatchJobTask(mlTask, ActionListener.wrap(exceedLimits -> {
if (exceedLimits) {
String error = "exceed maximum BATCH_INGEST Task limits";
String error =
"Exceeded maximum limit for BATCH_INGEST tasks. To increase the limit, update the plugins.ml_commons.max_batch_ingestion_tasks setting.";
log.warn(error + " in task " + mlTask.getTaskId());
listener.onFailure(new MLLimitExceededException(error));
listener.onFailure(new OpenSearchStatusException(error, RestStatus.TOO_MANY_REQUESTS));
} else {
mlTaskManager.createMLTask(mlTask, ActionListener.wrap(response -> {
String taskId = response.getId();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@
import org.opensearch.ml.common.dataset.MLInputDataType;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.exception.MLLimitExceededException;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.output.MLPredictionOutput;
Expand Down Expand Up @@ -257,9 +256,10 @@ protected void executeTask(MLPredictionTaskRequest request, ActionListener<MLTas
if (actionType.equals(ActionType.BATCH_PREDICT)) {
mlModelManager.checkMaxBatchJobTask(mlTask, ActionListener.wrap(exceedLimits -> {
if (exceedLimits) {
String error = "exceed maximum BATCH_PREDICTION Task limits";
String error =
"Exceeded maximum limit for BATCH_PREDICTION tasks. To increase the limit, update the plugins.ml_commons.max_batch_inference_tasks setting.";
log.warn(error + " in task " + mlTask.getTaskId());
listener.onFailure(new MLLimitExceededException(error));
listener.onFailure(new OpenSearchStatusException(error, RestStatus.TOO_MANY_REQUESTS));
} else {
executePredictionByInputDataType(inputDataType, modelId, mlInput, mlTask, functionName, listener);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.*;
import static org.mockito.Mockito.spy;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MAX_BATCH_INFERENCE_TASKS;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MODEL_AUTO_DEPLOY_ENABLE;

import java.io.IOException;
Expand Down Expand Up @@ -229,6 +230,11 @@ public void setup() throws IOException {

GetResult getResult = new GetResult(indexName, "1.1.1", 111l, 111l, 111l, true, bytesReference, null, null);
getResponse = new GetResponse(getResult);
doAnswer(invocation -> {
ActionListener<Boolean> listener = invocation.getArgument(1);
listener.onResponse(false);
return null;
}).when(mlModelManager).checkMaxBatchJobTask(any(MLTask.class), isA(ActionListener.class));
}

public void testExecuteTask_OnLocalNode() {
Expand Down

0 comments on commit 25dd297

Please sign in to comment.