Skip to content

Commit

Permalink
Fix model/connector update API to address security concern (opensearc…
Browse files Browse the repository at this point in the history
…h-project#1595)

* Fix model/connector update API to address appsec concern

Signed-off-by: Sicheng Song <[email protected]>

* Fix compile and build failure

Signed-off-by: Sicheng Song <[email protected]>

* Improve unit test coverage

Signed-off-by: Sicheng Song <[email protected]>

* Fix spotless

Signed-off-by: Sicheng Song <[email protected]>

* Merge update connector feature flag to remote inference feature flag

Signed-off-by: Sicheng Song <[email protected]>

* Fix compile

Signed-off-by: Sicheng Song <[email protected]>

* Fix exception status

Signed-off-by: Sicheng Song <[email protected]>

* Keep fixing exception status

Signed-off-by: Sicheng Song <[email protected]>

* Spotless fix

Signed-off-by: Sicheng Song <[email protected]>

* Add UT on parsing exception

Signed-off-by: Sicheng Song <[email protected]>

---------

Signed-off-by: Sicheng Song <[email protected]>
  • Loading branch information
b4sjoo authored and ylwu-amzn committed Nov 20, 2023
1 parent 373b694 commit b72cc90
Show file tree
Hide file tree
Showing 19 changed files with 293 additions and 149 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.common.settings.Settings;
Expand All @@ -38,12 +39,17 @@ public class MLUpdateModelInputTest {
private final String expectedInputStr = "{\"model_id\":\"test-model_id\",\"name\":\"name\",\"description\":\"description\",\"model_version\":\"2\",\"model_group_id\":\"modelGroupId\",\"model_config\":" +
"{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"" +
"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"},\"connector_id\":\"test-connector_id\"}";
private final String expectedInputStrWithNullField = "{\"model_id\":\"test-model_id\",\"name\":null,\"description\":\"description\",\"model_version\":\"2\",\"model_group_id\":\"modelGroupId\",\"model_config\":" +
"{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"" +
"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"},\"connector_id\":\"test-connector_id\"}";
private final String expectedOutputStr = "{\"model_id\":\"test-model_id\",\"name\":\"name\",\"description\":\"description\",\"model_version\":\"2\",\"model_group_id\":\"modelGroupId\",\"model_config\":" +
"{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"" +
"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"},\"connector_id\":\"test-connector_id\"}";
private final String expectedInputStrWithIllegalField = "{\"model_id\":\"test-model_id\",\"description\":\"description\",\"model_version\":\"2\",\"name\":\"name\",\"model_group_id\":\"modelGroupId\",\"model_config\":" +
"{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"" +
"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"},\"connector_id\":\"test-connector_id\",\"illegal_field\":\"This field need to be skipped.\"}";
@Rule
public ExpectedException exceptionRule = ExpectedException.none();

@Before
public void setUp() throws Exception {
Expand Down Expand Up @@ -109,6 +115,18 @@ public void parse_Success() throws Exception {
});
}

@Test
public void parse_WithNullFieldWithoutModel() throws Exception {
exceptionRule.expect(IllegalStateException.class);
testParseFromJsonString(expectedInputStrWithNullField, parsedInput -> {
try {
assertEquals(expectedOutputStr, serializationWithToXContent(parsedInput));
} catch (IOException e) {
throw new RuntimeException(e);
}
});
}

@Test
public void parse_WithIllegalFieldWithoutModel() throws Exception {
testParseFromJsonString(expectedInputStrWithIllegalField, parsedInput -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import static org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry;
import static org.opensearch.ml.utils.RestActionUtils.getFetchSourceContext;

import org.opensearch.OpenSearchStatusException;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.get.GetRequest;
import org.opensearch.action.support.ActionFilters;
Expand All @@ -19,11 +20,11 @@
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.index.IndexNotFoundException;
import org.opensearch.ml.common.connector.Connector;
import org.opensearch.ml.common.exception.MLValidationException;
import org.opensearch.ml.common.transport.connector.MLConnectorGetAction;
import org.opensearch.ml.common.transport.connector.MLConnectorGetRequest;
import org.opensearch.ml.common.transport.connector.MLConnectorGetResponse;
Expand Down Expand Up @@ -79,20 +80,31 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLConn
if (connectorAccessControlHelper.hasPermission(user, mlConnector)) {
actionListener.onResponse(MLConnectorGetResponse.builder().mlConnector(mlConnector).build());
} else {
actionListener.onFailure(new MLValidationException("You don't have permission to access this connector"));
actionListener
.onFailure(
new OpenSearchStatusException(
"You don't have permission to access this connector",
RestStatus.FORBIDDEN
)
);
}
} catch (Exception e) {
log.error("Failed to parse ml connector" + r.getId(), e);
actionListener.onFailure(e);
}
} else {
actionListener
.onFailure(new IllegalArgumentException("Failed to find connector with the provided connector id: " + connectorId));
.onFailure(
new OpenSearchStatusException(
"Failed to find connector with the provided connector id: " + connectorId,
RestStatus.NOT_FOUND
)
);
}
}, e -> {
if (e instanceof IndexNotFoundException) {
log.error("Failed to get connector index", e);
actionListener.onFailure(new IllegalArgumentException("Fail to find connector"));
actionListener.onFailure(new OpenSearchStatusException("Failed to find connector", RestStatus.NOT_FOUND));
} else {
log.error("Failed to get ML connector " + connectorId, e);
actionListener.onFailure(e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import java.util.Arrays;
import java.util.List;

import org.opensearch.OpenSearchStatusException;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.DocWriteResponse;
import org.opensearch.action.search.SearchRequest;
Expand All @@ -27,13 +28,13 @@
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.index.IndexNotFoundException;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.exception.MLValidationException;
import org.opensearch.ml.common.transport.connector.MLUpdateConnectorAction;
import org.opensearch.ml.common.transport.connector.MLUpdateConnectorRequest;
import org.opensearch.ml.engine.MLEngine;
Expand Down Expand Up @@ -136,10 +137,11 @@ private void updateUndeployedConnector(
}
listener
.onFailure(
new MLValidationException(
new OpenSearchStatusException(
searchHits.length
+ " models are still using this connector, please undeploy the models first: "
+ Arrays.toString(modelIds.toArray(new String[0]))
+ Arrays.toString(modelIds.toArray(new String[0])),
RestStatus.BAD_REQUEST
)
);
}
Expand Down
Loading

0 comments on commit b72cc90

Please sign in to comment.