diff --git a/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java b/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java index 2092b9f4b4..8ab548afa8 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java @@ -930,6 +930,25 @@ public Map predictTextEmbeddingModel(String modelId, MLInput input) throws IOExc return parseResponseToMap(response); } + public Map predictTextEmbeddingModelIgnoreFunctionName(String modelId, MLInput mlInput) throws IOException { + Response response = null; + try { + response = TestHelper + .makeRequest( + client(), + "POST", + "/_plugins/_ml/models/" + modelId + "/_predict", + null, + TestHelper.toJsonString(mlInput), + null + ); + } catch (ResponseException e) { + log.error(e.getMessage(), e); + response = e.getResponse(); + } + return parseResponseToMap(response); + } + public Consumer> verifyTextEmbeddingModelDeployed() { return (modelProfile) -> { if (modelProfile.containsKey("model_state")) { diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java index d8e21471d9..bf074a1073 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java @@ -15,6 +15,7 @@ import org.junit.Before; import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.connector.ConnectorAction; import org.opensearch.ml.common.dataset.TextDocsInputDataSet; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.MLInput; @@ -327,9 +328,10 @@ public void test_bedrock_embedding_v2_model_with_request_dimensions() throws Exc RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet .builder() .parameters(Map.of("inputText", "Can you tell me a joke?", "dimensions", "512")) + .actionType(ConnectorAction.ActionType.PREDICT) .build(); MLInput mlInput = MLInput.builder().inputDataset(inputDataSet).algorithm(FunctionName.REMOTE).build(); - Map inferenceResult = predictTextEmbeddingModel(modelId, mlInput); + Map inferenceResult = predictTextEmbeddingModelIgnoreFunctionName(modelId, mlInput); String errorMsg = String .format(Locale.ROOT, "Failing test case name: %s, inference result: %s", testCaseName, gson.toJson(inferenceResult)); assertTrue(errorMsg, inferenceResult.containsKey("inference_results"));