Skip to content

Commit

Permalink
Create a failing unit test for bug in path parameter consumption.
Browse files Browse the repository at this point in the history
Signed-off-by: Nathalie Jonathan <[email protected]>
  • Loading branch information
nathaliellenaa committed Dec 23, 2024
1 parent 58903ba commit 14f96c1
Showing 1 changed file with 44 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import static org.opensearch.ml.utils.MLExceptionUtils.REMOTE_INFERENCE_DISABLED_ERR_MSG;
import static org.opensearch.ml.utils.TestHelper.clusterSetting;

import java.util.HashMap;
import java.util.List;
import java.util.Map;

Expand Down Expand Up @@ -209,6 +210,49 @@ public void testRegisterModelRequestWithNullModelID() throws Exception {
assertEquals("2", registerModelInput.getVersion());
}

public void testRegisterModelRequestWithPathParameters() throws Exception {
RestRequest.Method method = RestRequest.Method.POST;
final Map<String, String> params = new HashMap<>();
params.put("model_id", "test-model-123");
params.put("version", "1.0.0");

final Map<String, Object> modelConfig = Map
.of("model_type", "bert", "embedding_dimension", 384, "framework_type", "sentence_transformers", "all_config", "All Config");
final Map<String, Object> model = Map
.of(
"name",
"test_model",
"model_id",
"test_model_with_modelId",
"version",
"1",
"model_group_id",
"modelGroupId",
"url",
"testUrl",
"model_format",
"TORCH_SCRIPT",
"model_config",
modelConfig
);

String requestContent = new Gson().toJson(model);
RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY)
.withMethod(method)
.withPath("/_plugins/_ml/models/{model_id}/{version}/_register")
.withParams(params)
.withContent(new BytesArray(requestContent), XContentType.JSON)
.build();

restMLRegisterModelAction.handleRequest(request, channel, client);
ArgumentCaptor<MLRegisterModelRequest> argumentCaptor = ArgumentCaptor.forClass(MLRegisterModelRequest.class);
verify(client, times(1)).execute(eq(MLRegisterModelAction.INSTANCE), argumentCaptor.capture(), any());
MLRegisterModelInput registerModelInput = argumentCaptor.getValue().getRegisterModelInput();
assertEquals("test_model", registerModelInput.getModelName());
assertEquals("1", registerModelInput.getVersion());
assertEquals("TORCH_SCRIPT", registerModelInput.getModelFormat().toString());
}

private RestRequest getRestRequest() {
RestRequest.Method method = RestRequest.Method.POST;
final Map<String, Object> modelConfig = Map
Expand Down

0 comments on commit 14f96c1

Please sign in to comment.