diff --git a/plugin/src/main/java/org/opensearch/ml/action/batch/TransportBatchIngestionAction.java b/plugin/src/main/java/org/opensearch/ml/action/batch/TransportBatchIngestionAction.java index cdb32dde84..3c6df2e387 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/batch/TransportBatchIngestionAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/batch/TransportBatchIngestionAction.java @@ -33,7 +33,6 @@ import org.opensearch.ml.common.MLTask; import org.opensearch.ml.common.MLTaskState; import org.opensearch.ml.common.MLTaskType; -import org.opensearch.ml.common.exception.MLLimitExceededException; import org.opensearch.ml.common.transport.batch.MLBatchIngestionAction; import org.opensearch.ml.common.transport.batch.MLBatchIngestionInput; import org.opensearch.ml.common.transport.batch.MLBatchIngestionRequest; @@ -146,9 +145,10 @@ protected void createMLTaskandExecute(MLBatchIngestionInput mlBatchIngestionInpu mlModelManager.checkMaxBatchJobTask(mlTask, ActionListener.wrap(exceedLimits -> { if (exceedLimits) { - String error = "exceed maximum BATCH_INGEST Task limits"; + String error = + "Exceeded maximum limit for BATCH_INGEST tasks. To increase the limit, update the plugins.ml_commons.max_batch_ingestion_tasks setting."; log.warn(error + " in task " + mlTask.getTaskId()); - listener.onFailure(new MLLimitExceededException(error)); + listener.onFailure(new OpenSearchStatusException(error, RestStatus.TOO_MANY_REQUESTS)); } else { mlTaskManager.createMLTask(mlTask, ActionListener.wrap(response -> { String taskId = response.getId(); diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java index 17b22d82d5..8c30545222 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java @@ -55,7 +55,6 @@ import org.opensearch.ml.common.dataset.MLInputDataType; import org.opensearch.ml.common.dataset.MLInputDataset; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; -import org.opensearch.ml.common.exception.MLLimitExceededException; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.output.MLOutput; import org.opensearch.ml.common.output.MLPredictionOutput; @@ -257,9 +256,10 @@ protected void executeTask(MLPredictionTaskRequest request, ActionListener { if (exceedLimits) { - String error = "exceed maximum BATCH_PREDICTION Task limits"; + String error = + "Exceeded maximum limit for BATCH_PREDICTION tasks. To increase the limit, update the plugins.ml_commons.max_batch_inference_tasks setting."; log.warn(error + " in task " + mlTask.getTaskId()); - listener.onFailure(new MLLimitExceededException(error)); + listener.onFailure(new OpenSearchStatusException(error, RestStatus.TOO_MANY_REQUESTS)); } else { executePredictionByInputDataType(inputDataType, modelId, mlInput, mlTask, functionName, listener); } diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java index 223f2ce5a5..9aaf416adc 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java @@ -9,6 +9,7 @@ import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.*; import static org.mockito.Mockito.spy; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MAX_BATCH_INFERENCE_TASKS; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MODEL_AUTO_DEPLOY_ENABLE; import java.io.IOException; @@ -229,6 +230,11 @@ public void setup() throws IOException { GetResult getResult = new GetResult(indexName, "1.1.1", 111l, 111l, 111l, true, bytesReference, null, null); getResponse = new GetResponse(getResult); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(false); + return null; + }).when(mlModelManager).checkMaxBatchJobTask(any(MLTask.class), isA(ActionListener.class)); } public void testExecuteTask_OnLocalNode() {