forked from opensearch-project/ml-commons
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Backport to main] update connector API (opensearch-project#1651)
* update connector API Signed-off-by: Xun Zhang <[email protected]> * more ut test coverage Signed-off-by: Xun Zhang <[email protected]> * check connector usage in deployed models before updating connector Signed-off-by: Xun Zhang <[email protected]> --------- Signed-off-by: Xun Zhang <[email protected]> Co-authored-by: Xun Zhang <[email protected]>
- Loading branch information
1 parent
4d53db5
commit 5759bf2
Showing
9 changed files
with
987 additions
and
4 deletions.
There are no files selected for viewing
16 changes: 16 additions & 0 deletions
16
...n/src/main/java/org/opensearch/ml/common/transport/connector/MLUpdateConnectorAction.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.common.transport.connector; | ||
|
||
import org.opensearch.action.ActionType; | ||
import org.opensearch.action.update.UpdateResponse; | ||
|
||
public class MLUpdateConnectorAction extends ActionType<UpdateResponse> { | ||
public static final MLUpdateConnectorAction INSTANCE = new MLUpdateConnectorAction(); | ||
public static final String NAME = "cluster:admin/opensearch/ml/connectors/update"; | ||
|
||
private MLUpdateConnectorAction() { super(NAME, UpdateResponse::new);} | ||
} |
83 changes: 83 additions & 0 deletions
83
.../src/main/java/org/opensearch/ml/common/transport/connector/MLUpdateConnectorRequest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.common.transport.connector; | ||
|
||
import lombok.Builder; | ||
import lombok.Getter; | ||
import org.opensearch.action.ActionRequest; | ||
import org.opensearch.action.ActionRequestValidationException; | ||
import org.opensearch.core.common.io.stream.InputStreamStreamInput; | ||
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; | ||
import org.opensearch.core.common.io.stream.StreamInput; | ||
import org.opensearch.core.common.io.stream.StreamOutput; | ||
import org.opensearch.core.xcontent.XContentParser; | ||
|
||
import java.io.ByteArrayInputStream; | ||
import java.io.ByteArrayOutputStream; | ||
import java.io.IOException; | ||
import java.io.UncheckedIOException; | ||
import java.util.Map; | ||
|
||
import static org.opensearch.action.ValidateActions.addValidationError; | ||
|
||
@Getter | ||
public class MLUpdateConnectorRequest extends ActionRequest { | ||
String connectorId; | ||
Map<String, Object> updateContent; | ||
|
||
@Builder | ||
public MLUpdateConnectorRequest(String connectorId, Map<String, Object> updateContent) { | ||
this.connectorId = connectorId; | ||
this.updateContent = updateContent; | ||
} | ||
|
||
public MLUpdateConnectorRequest(StreamInput in) throws IOException { | ||
super(in); | ||
this.connectorId = in.readString(); | ||
this.updateContent = in.readMap(); | ||
} | ||
|
||
@Override | ||
public void writeTo(StreamOutput out) throws IOException { | ||
super.writeTo(out); | ||
out.writeString(this.connectorId); | ||
out.writeMap(this.getUpdateContent()); | ||
} | ||
|
||
@Override | ||
public ActionRequestValidationException validate() { | ||
ActionRequestValidationException exception = null; | ||
|
||
if (this.connectorId == null) { | ||
exception = addValidationError("ML connector id can't be null", exception); | ||
} | ||
|
||
return exception; | ||
} | ||
|
||
public static MLUpdateConnectorRequest parse(XContentParser parser, String connectorId) throws IOException { | ||
Map<String, Object> dataAsMap = null; | ||
dataAsMap = parser.map(); | ||
|
||
return MLUpdateConnectorRequest.builder().connectorId(connectorId).updateContent(dataAsMap).build(); | ||
} | ||
|
||
public static MLUpdateConnectorRequest fromActionRequest(ActionRequest actionRequest) { | ||
if (actionRequest instanceof MLUpdateConnectorRequest) { | ||
return (MLUpdateConnectorRequest) actionRequest; | ||
} | ||
|
||
try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); | ||
OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { | ||
actionRequest.writeTo(osso); | ||
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { | ||
return new MLUpdateConnectorRequest(input); | ||
} | ||
} catch (IOException e) { | ||
throw new UncheckedIOException("failed to parse ActionRequest into MLUpdateConnectorRequest", e); | ||
} | ||
} | ||
} |
128 changes: 128 additions & 0 deletions
128
...test/java/org/opensearch/ml/common/transport/connector/MLUpdateConnectorRequestTests.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.common.transport.connector; | ||
|
||
import org.junit.Before; | ||
import org.junit.Test; | ||
import org.mockito.Mock; | ||
import org.mockito.MockitoAnnotations; | ||
import org.opensearch.action.ActionRequest; | ||
import org.opensearch.action.ActionRequestValidationException; | ||
import org.opensearch.common.io.stream.BytesStreamOutput; | ||
import org.opensearch.core.common.io.stream.StreamOutput; | ||
import org.opensearch.core.xcontent.XContentParser; | ||
import org.opensearch.rest.RestRequest; | ||
|
||
import java.io.IOException; | ||
import java.io.UncheckedIOException; | ||
import java.util.Map; | ||
|
||
import static org.junit.Assert.assertEquals; | ||
import static org.junit.Assert.assertNotSame; | ||
import static org.junit.Assert.assertNull; | ||
import static org.junit.Assert.assertSame; | ||
import static org.mockito.Mockito.when; | ||
|
||
public class MLUpdateConnectorRequestTests { | ||
private String connectorId; | ||
private Map<String, Object> updateContent; | ||
private MLUpdateConnectorRequest mlUpdateConnectorRequest; | ||
|
||
@Mock | ||
XContentParser parser; | ||
|
||
@Before | ||
public void setUp() { | ||
MockitoAnnotations.openMocks(this); | ||
this.connectorId = "test-connector_id"; | ||
this.updateContent = Map.of("description", "new description"); | ||
mlUpdateConnectorRequest = MLUpdateConnectorRequest.builder() | ||
.connectorId(connectorId) | ||
.updateContent(updateContent) | ||
.build(); | ||
} | ||
|
||
@Test | ||
public void writeTo_Success() throws IOException { | ||
BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); | ||
mlUpdateConnectorRequest.writeTo(bytesStreamOutput); | ||
MLUpdateConnectorRequest parsedUpdateRequest = new MLUpdateConnectorRequest(bytesStreamOutput.bytes().streamInput()); | ||
assertEquals(connectorId, parsedUpdateRequest.getConnectorId()); | ||
assertEquals(updateContent, parsedUpdateRequest.getUpdateContent()); | ||
} | ||
|
||
@Test | ||
public void validate_Success() { | ||
assertNull(mlUpdateConnectorRequest.validate()); | ||
} | ||
|
||
@Test | ||
public void validate_Exception_NullConnectorId() { | ||
MLUpdateConnectorRequest updateConnectorRequest = MLUpdateConnectorRequest.builder().build(); | ||
Exception exception = updateConnectorRequest.validate(); | ||
|
||
assertEquals("Validation Failed: 1: ML connector id can't be null;", exception.getMessage()); | ||
} | ||
|
||
@Test | ||
public void parse_success() throws IOException { | ||
RestRequest.Method method = RestRequest.Method.POST; | ||
final Map<String, Object> updatefields = Map.of("version", "new version", "description", "new description"); | ||
when(parser.map()).thenReturn(updatefields); | ||
|
||
MLUpdateConnectorRequest updateConnectorRequest = MLUpdateConnectorRequest.parse(parser, connectorId); | ||
assertEquals(updateConnectorRequest.getConnectorId(), connectorId); | ||
assertEquals(updateConnectorRequest.getUpdateContent(), updatefields); | ||
} | ||
|
||
@Test | ||
public void fromActionRequest_Success() { | ||
MLUpdateConnectorRequest mlUpdateConnectorRequest = MLUpdateConnectorRequest.builder() | ||
.connectorId(connectorId) | ||
.updateContent(updateContent) | ||
.build(); | ||
assertSame(MLUpdateConnectorRequest.fromActionRequest(mlUpdateConnectorRequest), mlUpdateConnectorRequest); | ||
} | ||
|
||
@Test | ||
public void fromActionRequest_Success_fromActionRequest() { | ||
MLUpdateConnectorRequest mlUpdateConnectorRequest = MLUpdateConnectorRequest.builder() | ||
.connectorId(connectorId) | ||
.updateContent(updateContent) | ||
.build(); | ||
ActionRequest actionRequest = new ActionRequest() { | ||
@Override | ||
public ActionRequestValidationException validate() { | ||
return null; | ||
} | ||
|
||
@Override | ||
public void writeTo(StreamOutput out) throws IOException { | ||
mlUpdateConnectorRequest.writeTo(out); | ||
} | ||
}; | ||
MLUpdateConnectorRequest request = MLUpdateConnectorRequest.fromActionRequest(actionRequest); | ||
assertNotSame(request, mlUpdateConnectorRequest); | ||
assertEquals(mlUpdateConnectorRequest.getConnectorId(), request.getConnectorId()); | ||
assertEquals(mlUpdateConnectorRequest.getUpdateContent(), request.getUpdateContent()); | ||
} | ||
|
||
@Test(expected = UncheckedIOException.class) | ||
public void fromActionRequest_IOException() { | ||
ActionRequest actionRequest = new ActionRequest() { | ||
@Override | ||
public ActionRequestValidationException validate() { | ||
return null; | ||
} | ||
|
||
@Override | ||
public void writeTo(StreamOutput out) throws IOException { | ||
throw new IOException(); | ||
} | ||
}; | ||
MLUpdateConnectorRequest.fromActionRequest(actionRequest); | ||
} | ||
} |
146 changes: 146 additions & 0 deletions
146
plugin/src/main/java/org/opensearch/ml/action/connector/UpdateConnectorTransportAction.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,146 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.action.connector; | ||
|
||
import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX; | ||
import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; | ||
|
||
import org.opensearch.action.ActionRequest; | ||
import org.opensearch.action.DocWriteResponse; | ||
import org.opensearch.action.search.SearchRequest; | ||
import org.opensearch.action.support.ActionFilters; | ||
import org.opensearch.action.support.HandledTransportAction; | ||
import org.opensearch.action.update.UpdateRequest; | ||
import org.opensearch.action.update.UpdateResponse; | ||
import org.opensearch.client.Client; | ||
import org.opensearch.common.inject.Inject; | ||
import org.opensearch.common.util.concurrent.ThreadContext; | ||
import org.opensearch.core.action.ActionListener; | ||
import org.opensearch.index.IndexNotFoundException; | ||
import org.opensearch.index.query.BoolQueryBuilder; | ||
import org.opensearch.index.query.QueryBuilders; | ||
import org.opensearch.ml.common.MLModel; | ||
import org.opensearch.ml.common.exception.MLValidationException; | ||
import org.opensearch.ml.common.transport.connector.MLUpdateConnectorAction; | ||
import org.opensearch.ml.common.transport.connector.MLUpdateConnectorRequest; | ||
import org.opensearch.ml.helper.ConnectorAccessControlHelper; | ||
import org.opensearch.ml.model.MLModelManager; | ||
import org.opensearch.search.SearchHit; | ||
import org.opensearch.search.builder.SearchSourceBuilder; | ||
import org.opensearch.tasks.Task; | ||
import org.opensearch.transport.TransportService; | ||
|
||
import lombok.AccessLevel; | ||
import lombok.experimental.FieldDefaults; | ||
import lombok.extern.log4j.Log4j2; | ||
|
||
@Log4j2 | ||
@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) | ||
public class UpdateConnectorTransportAction extends HandledTransportAction<ActionRequest, UpdateResponse> { | ||
Client client; | ||
|
||
ConnectorAccessControlHelper connectorAccessControlHelper; | ||
MLModelManager mlModelManager; | ||
|
||
@Inject | ||
public UpdateConnectorTransportAction( | ||
TransportService transportService, | ||
ActionFilters actionFilters, | ||
Client client, | ||
ConnectorAccessControlHelper connectorAccessControlHelper, | ||
MLModelManager mlModelManager | ||
) { | ||
super(MLUpdateConnectorAction.NAME, transportService, actionFilters, MLUpdateConnectorRequest::new); | ||
this.client = client; | ||
this.connectorAccessControlHelper = connectorAccessControlHelper; | ||
this.mlModelManager = mlModelManager; | ||
} | ||
|
||
@Override | ||
protected void doExecute(Task task, ActionRequest request, ActionListener<UpdateResponse> listener) { | ||
MLUpdateConnectorRequest mlUpdateConnectorAction = MLUpdateConnectorRequest.fromActionRequest(request); | ||
String connectorId = mlUpdateConnectorAction.getConnectorId(); | ||
UpdateRequest updateRequest = new UpdateRequest(ML_CONNECTOR_INDEX, connectorId); | ||
updateRequest.doc(mlUpdateConnectorAction.getUpdateContent()); | ||
updateRequest.docAsUpsert(true); | ||
|
||
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { | ||
connectorAccessControlHelper.validateConnectorAccess(client, connectorId, ActionListener.wrap(hasPermission -> { | ||
if (Boolean.TRUE.equals(hasPermission)) { | ||
updateUndeployedConnector(connectorId, updateRequest, listener, context); | ||
} else { | ||
listener | ||
.onFailure( | ||
new IllegalArgumentException("You don't have permission to update the connector, connector id: " + connectorId) | ||
); | ||
} | ||
}, exception -> { | ||
log.error("Permission denied: Unable to update the connector with ID {}. Details: {}", connectorId, exception); | ||
listener.onFailure(exception); | ||
})); | ||
} catch (Exception e) { | ||
log.error("Failed to update ML connector for connector id {}. Details {}:", connectorId, e); | ||
listener.onFailure(e); | ||
} | ||
} | ||
|
||
private void updateUndeployedConnector( | ||
String connectorId, | ||
UpdateRequest updateRequest, | ||
ActionListener<UpdateResponse> listener, | ||
ThreadContext.StoredContext context | ||
) { | ||
SearchRequest searchRequest = new SearchRequest(ML_MODEL_INDEX); | ||
SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); | ||
BoolQueryBuilder boolQueryBuilder = QueryBuilders.boolQuery(); | ||
boolQueryBuilder.must(QueryBuilders.matchQuery(MLModel.CONNECTOR_ID_FIELD, connectorId)); | ||
boolQueryBuilder.must(QueryBuilders.idsQuery().addIds(mlModelManager.getAllModelIds())); | ||
sourceBuilder.query(boolQueryBuilder); | ||
searchRequest.source(sourceBuilder); | ||
|
||
client.search(searchRequest, ActionListener.wrap(searchResponse -> { | ||
SearchHit[] searchHits = searchResponse.getHits().getHits(); | ||
if (searchHits.length == 0) { | ||
client.update(updateRequest, getUpdateResponseListener(connectorId, listener, context)); | ||
} else { | ||
log.error(searchHits.length + " models are still using this connector, please undeploy the models first!"); | ||
listener | ||
.onFailure( | ||
new MLValidationException( | ||
searchHits.length + " models are still using this connector, please undeploy the models first!" | ||
) | ||
); | ||
} | ||
}, e -> { | ||
if (e instanceof IndexNotFoundException) { | ||
client.update(updateRequest, getUpdateResponseListener(connectorId, listener, context)); | ||
return; | ||
} | ||
log.error("Failed to update ML connector: " + connectorId, e); | ||
listener.onFailure(e); | ||
|
||
})); | ||
} | ||
|
||
private ActionListener<UpdateResponse> getUpdateResponseListener( | ||
String connectorId, | ||
ActionListener<UpdateResponse> actionListener, | ||
ThreadContext.StoredContext context | ||
) { | ||
return ActionListener.runBefore(ActionListener.wrap(updateResponse -> { | ||
if (updateResponse != null && updateResponse.getResult() != DocWriteResponse.Result.UPDATED) { | ||
log.info("Failed to update the connector with ID: {}", connectorId); | ||
actionListener.onResponse(updateResponse); | ||
return; | ||
} | ||
log.info("Successfully updated the connector with ID: {}", connectorId); | ||
actionListener.onResponse(updateResponse); | ||
}, exception -> { | ||
log.error("Failed to update ML connector with ID {}. Details: {}", connectorId, exception); | ||
actionListener.onFailure(exception); | ||
}), context::restore); | ||
} | ||
} |
Oops, something went wrong.