From d0a749159918988546a60b3db50f2986c7fc85c0 Mon Sep 17 00:00:00 2001 From: Bhavana Ramaram Date: Fri, 24 Nov 2023 23:44:52 +0530 Subject: [PATCH] remove return_content field and add more tests Signed-off-by: Bhavana Ramaram --- .../model_group/MLModelGroupGetRequest.java | 6 +- .../GetModelGroupTransportAction.java | 11 +- .../ml/rest/RestMLGetModelGroupAction.java | 4 +- .../ml/action/MLCommonsIntegTestCase.java | 2 +- .../ml/rest/MLCommonsRestTestCase.java | 11 +- .../ml/rest/MLModelGroupRestIT.java | 153 +++++++++++++++++- 6 files changed, 163 insertions(+), 24 deletions(-) diff --git a/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLModelGroupGetRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLModelGroupGetRequest.java index 65fad3b78d..265b16b9d1 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLModelGroupGetRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLModelGroupGetRequest.java @@ -30,25 +30,21 @@ public class MLModelGroupGetRequest extends ActionRequest { String modelGroupId; - boolean returnContent; @Builder - public MLModelGroupGetRequest(String modelGroupId, boolean returnContent) { + public MLModelGroupGetRequest(String modelGroupId) { this.modelGroupId = modelGroupId; - this.returnContent = returnContent; } public MLModelGroupGetRequest(StreamInput in) throws IOException { super(in); this.modelGroupId = in.readString(); - this.returnContent = in.readBoolean(); } @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); out.writeString(this.modelGroupId); - out.writeBoolean(returnContent); } @Override diff --git a/plugin/src/main/java/org/opensearch/ml/action/model_group/GetModelGroupTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/model_group/GetModelGroupTransportAction.java index b7fe030e87..a846c9c0f6 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/model_group/GetModelGroupTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/model_group/GetModelGroupTransportAction.java @@ -8,7 +8,6 @@ import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX; import static org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry; -import static org.opensearch.ml.utils.RestActionUtils.getFetchSourceContext; import org.opensearch.OpenSearchStatusException; import org.opensearch.action.ActionRequest; @@ -27,13 +26,11 @@ import org.opensearch.index.IndexNotFoundException; import org.opensearch.ml.common.MLModelGroup; import org.opensearch.ml.common.exception.MLResourceNotFoundException; -import org.opensearch.ml.common.exception.MLValidationException; import org.opensearch.ml.common.transport.model_group.MLModelGroupGetAction; import org.opensearch.ml.common.transport.model_group.MLModelGroupGetRequest; import org.opensearch.ml.common.transport.model_group.MLModelGroupGetResponse; import org.opensearch.ml.helper.ModelAccessControlHelper; import org.opensearch.ml.utils.RestActionUtils; -import org.opensearch.search.fetch.subphase.FetchSourceContext; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; @@ -71,8 +68,7 @@ public GetModelGroupTransportAction( protected void doExecute(Task task, ActionRequest request, ActionListener actionListener) { MLModelGroupGetRequest mlModelGroupGetRequest = MLModelGroupGetRequest.fromActionRequest(request); String modelGroupId = mlModelGroupGetRequest.getModelGroupId(); - FetchSourceContext fetchSourceContext = getFetchSourceContext(mlModelGroupGetRequest.isReturnContent()); - GetRequest getRequest = new GetRequest(ML_MODEL_GROUP_INDEX).id(modelGroupId).fetchSourceContext(fetchSourceContext); + GetRequest getRequest = new GetRequest(ML_MODEL_GROUP_INDEX).id(modelGroupId); User user = RestActionUtils.getUserContext(client); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { @@ -87,8 +83,9 @@ protected void doExecute(Task task, ActionRequest request, ActionListener> function) throws IOException { + Response response = TestHelper.makeRequest(client, "GET", "/_plugins/_ml/model_groups/" + modelIGroupd, null, "", null); + verifyResponse(function, response); + } + public void searchModelGroups(RestClient client, String query, Consumer> function) throws IOException { Response response = TestHelper.makeRequest(client, "GET", "/_plugins/_ml/model_groups/_search", null, query, null); verifyResponse(function, response); diff --git a/plugin/src/test/java/org/opensearch/ml/rest/MLModelGroupRestIT.java b/plugin/src/test/java/org/opensearch/ml/rest/MLModelGroupRestIT.java index 92ea494cf5..4f04fe92cf 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/MLModelGroupRestIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/MLModelGroupRestIT.java @@ -7,11 +7,8 @@ package org.opensearch.ml.rest; -import java.io.IOException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Map; - +import com.google.common.base.Throwables; +import com.google.common.collect.ImmutableList; import org.apache.hc.core5.http.HttpHeaders; import org.apache.hc.core5.http.HttpHost; import org.apache.hc.core5.http.message.BasicHeader; @@ -28,8 +25,10 @@ import org.opensearch.ml.common.transport.model_group.MLUpdateModelGroupInput; import org.opensearch.ml.utils.TestHelper; -import com.google.common.base.Throwables; -import com.google.common.collect.ImmutableList; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Map; public class MLModelGroupRestIT extends MLCommonsRestTestCase { @@ -1212,4 +1211,144 @@ public void test_search_MatchAllQuery_For_ModelGroups() throws IOException { } } + public void test_get_modelGroup() throws IOException { + mlRegisterModelGroupInput = createRegisterModelGroupInput("testModelGroup1", Arrays.asList("IT"), AccessMode.RESTRICTED, null); + registerModelGroup(user1Client, TestHelper.toJsonString(mlRegisterModelGroupInput), registerModelGroupResult -> { + String modelGroupId1 = (String) registerModelGroupResult.get("model_group_id"); + assertTrue(registerModelGroupResult.containsKey("model_group_id")); + try { + // User2 successfully gets model group since user2 has IT backend role + getModelGroup( + user2Client, + modelGroupId1, + getModelGroupResult -> { assertTrue(getModelGroupResult.containsKey("model_group_id")); } + ); + + // Admin successfully gets model group + getModelGroup( + client(), + modelGroupId1, + getModelGroupResult -> { assertTrue(getModelGroupResult.containsKey("model_group_id")); } + ); + } catch (IOException e) { + assertNull(e); + } + // User2 fails to get model group + try { + getModelGroup(user3Client, modelGroupId, null); + } catch (Exception e) { + assertEquals(ResponseException.class, e.getClass()); + assertTrue( + Throwables + .getStackTraceAsString(e) + .contains("User doesn't have privilege to perform this operation on this model group") + ); + } + }); + + mlRegisterModelGroupInput = createRegisterModelGroupInput("testModelGroup2", null, AccessMode.PUBLIC, null); + registerModelGroup(user2Client, TestHelper.toJsonString(mlRegisterModelGroupInput), registerModelGroupResult -> { + String modelGroupId2 = (String) registerModelGroupResult.get("model_group_id"); + assertTrue(registerModelGroupResult.containsKey("model_group_id")); + try { + // User1 successfully gets model group since user2 has IT backend role + getModelGroup( + user1Client, + modelGroupId2, + getModelGroupResult -> { assertTrue(getModelGroupResult.containsKey("model_group_id")); } + ); + + // User3 successfully gets model group + getModelGroup( + user3Client, + modelGroupId2, + getModelGroupResult -> { assertTrue(getModelGroupResult.containsKey("model_group_id")); } + ); + + // User4 successfully gets model group + getModelGroup( + user4Client, + modelGroupId2, + getModelGroupResult -> { assertTrue(getModelGroupResult.containsKey("model_group_id")); } + ); + } catch (IOException e) { + assertNull(e); + } + }); + + mlRegisterModelGroupInput = createRegisterModelGroupInput("testModelGroup3", null, AccessMode.PRIVATE, null); + registerModelGroup(user3Client, TestHelper.toJsonString(mlRegisterModelGroupInput), registerModelGroupResult -> { + String modelGroupId3 = (String) registerModelGroupResult.get("model_group_id"); + assertTrue(registerModelGroupResult.containsKey("model_group_id")); + try { + // User3 successfully gets model group since user2 has IT backend role + getModelGroup( + user3Client, + modelGroupId3, + getModelGroupResult -> { assertTrue(getModelGroupResult.containsKey("model_group_id")); } + ); + + // Admin successfully gets model group + getModelGroup( + client(), + modelGroupId3, + getModelGroupResult -> { assertTrue(getModelGroupResult.containsKey("model_group_id")); } + ); + } catch (IOException e) { + assertNull(e); + } + // User2 fails to get model group + try { + getModelGroup(user2Client, modelGroupId3, null); + } catch (Exception e) { + assertEquals(ResponseException.class, e.getClass()); + assertTrue( + Throwables + .getStackTraceAsString(e) + .contains("User doesn't have privilege to perform this operation on this model group") + ); + } + }); + + mlRegisterModelGroupInput = createRegisterModelGroupInput("testModelGroup4", null, null, null); + registerModelGroup(client(), TestHelper.toJsonString(mlRegisterModelGroupInput), registerModelGroupResult -> { + String modelGroupId4 = (String) registerModelGroupResult.get("model_group_id"); + assertTrue(registerModelGroupResult.containsKey("model_group_id")); + try { + // Admin successfully gets model group + getModelGroup( + client(), + modelGroupId4, + getModelGroupResult -> { assertTrue(getModelGroupResult.containsKey("model_group_id")); } + ); + } catch (IOException e) { + assertNull(e); + } + + // User1 fails to get model group + try { + getModelGroup(user1Client, modelGroupId4, null); + } catch (Exception e) { + assertEquals(ResponseException.class, e.getClass()); + assertTrue( + Throwables + .getStackTraceAsString(e) + .contains("User doesn't have privilege to perform this operation on this model group") + ); + } + + // User2 fails to get model group + try { + getModelGroup(user2Client, modelGroupId4, null); + } catch (Exception e) { + assertEquals(ResponseException.class, e.getClass()); + assertTrue( + Throwables + .getStackTraceAsString(e) + .contains("User doesn't have privilege to perform this operation on this model group") + ); + } + }); + } + }