Skip to content

Commit

Permalink
check connector usage in deployed models before updating connector
Browse files Browse the repository at this point in the history
Signed-off-by: Xun Zhang <[email protected]>
  • Loading branch information
Zhangxunmt authored and rbhavna committed Nov 16, 2023
1 parent 9c70503 commit d73e3e5
Show file tree
Hide file tree
Showing 2 changed files with 187 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@
package org.opensearch.ml.action.connector;

import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX;
import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX;

import org.opensearch.action.ActionRequest;
import org.opensearch.action.DocWriteResponse;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.action.update.UpdateRequest;
Expand All @@ -17,9 +19,17 @@
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.index.IndexNotFoundException;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.exception.MLValidationException;
import org.opensearch.ml.common.transport.connector.MLUpdateConnectorAction;
import org.opensearch.ml.common.transport.connector.MLUpdateConnectorRequest;
import org.opensearch.ml.helper.ConnectorAccessControlHelper;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.search.SearchHit;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;

Expand All @@ -33,17 +43,20 @@ public class UpdateConnectorTransportAction extends HandledTransportAction<Actio
Client client;

ConnectorAccessControlHelper connectorAccessControlHelper;
MLModelManager mlModelManager;

@Inject
public UpdateConnectorTransportAction(
TransportService transportService,
ActionFilters actionFilters,
Client client,
ConnectorAccessControlHelper connectorAccessControlHelper
ConnectorAccessControlHelper connectorAccessControlHelper,
MLModelManager mlModelManager
) {
super(MLUpdateConnectorAction.NAME, transportService, actionFilters, MLUpdateConnectorRequest::new);
this.client = client;
this.connectorAccessControlHelper = connectorAccessControlHelper;
this.mlModelManager = mlModelManager;
}

@Override
Expand All @@ -57,7 +70,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<Update
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
connectorAccessControlHelper.validateConnectorAccess(client, connectorId, ActionListener.wrap(hasPermission -> {
if (Boolean.TRUE.equals(hasPermission)) {
client.update(updateRequest, getUpdateResponseListener(connectorId, listener, context));
updateUndeployedConnector(connectorId, updateRequest, listener, context);
} else {
listener
.onFailure(
Expand All @@ -74,6 +87,44 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<Update
}
}

private void updateUndeployedConnector(
String connectorId,
UpdateRequest updateRequest,
ActionListener<UpdateResponse> listener,
ThreadContext.StoredContext context
) {
SearchRequest searchRequest = new SearchRequest(ML_MODEL_INDEX);
SearchSourceBuilder sourceBuilder = new SearchSourceBuilder();
BoolQueryBuilder boolQueryBuilder = QueryBuilders.boolQuery();
boolQueryBuilder.must(QueryBuilders.matchQuery(MLModel.CONNECTOR_ID_FIELD, connectorId));
boolQueryBuilder.must(QueryBuilders.idsQuery().addIds(mlModelManager.getAllModelIds()));
sourceBuilder.query(boolQueryBuilder);
searchRequest.source(sourceBuilder);

client.search(searchRequest, ActionListener.wrap(searchResponse -> {
SearchHit[] searchHits = searchResponse.getHits().getHits();
if (searchHits.length == 0) {
client.update(updateRequest, getUpdateResponseListener(connectorId, listener, context));
} else {
log.error(searchHits.length + " models are still using this connector, please undeploy the models first!");
listener
.onFailure(
new MLValidationException(
searchHits.length + " models are still using this connector, please undeploy the models first!"
)
);
}
}, e -> {
if (e instanceof IndexNotFoundException) {
client.update(updateRequest, getUpdateResponseListener(connectorId, listener, context));
return;
}
log.error("Failed to update ML connector: " + connectorId, e);
listener.onFailure(e);

}));
}

private ActionListener<UpdateResponse> getUpdateResponseListener(
String connectorId,
ActionListener<UpdateResponse> actionListener,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,20 @@
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX;
import static org.opensearch.ml.utils.TestHelper.clusterSetting;

import java.io.IOException;
import java.util.List;
import java.util.Map;

import org.apache.lucene.search.TotalHits;
import org.junit.Before;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.opensearch.action.DocWriteResponse;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.search.SearchResponseSections;
import org.opensearch.action.search.ShardSearchFailure;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.update.UpdateRequest;
import org.opensearch.action.update.UpdateResponse;
Expand All @@ -34,8 +40,14 @@
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.index.Index;
import org.opensearch.core.index.shard.ShardId;
import org.opensearch.index.IndexNotFoundException;
import org.opensearch.ml.common.transport.connector.MLUpdateConnectorRequest;
import org.opensearch.ml.helper.ConnectorAccessControlHelper;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.utils.TestHelper;
import org.opensearch.search.SearchHit;
import org.opensearch.search.SearchHits;
import org.opensearch.search.aggregations.InternalAggregations;
import org.opensearch.tasks.Task;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.threadpool.ThreadPool;
Expand Down Expand Up @@ -77,17 +89,22 @@ public class TransportUpdateConnectorActionTests extends OpenSearchTestCase {
@Mock
ActionListener<UpdateResponse> actionListener;

@Mock
MLModelManager mlModelManager;

ThreadContext threadContext;

private Settings settings;

private ShardId shardId;

private SearchResponse searchResponse;

private static final List<String> TRUSTED_CONNECTOR_ENDPOINTS_REGEXES = ImmutableList
.of("^https://runtime\\.sagemaker\\..*\\.amazonaws\\.com/.*$", "^https://api\\.openai\\.com/.*$", "^https://api\\.cohere\\.ai/.*$");

@Before
public void setup() {
public void setup() throws IOException {
MockitoAnnotations.openMocks(this);
settings = Settings
.builder()
Expand All @@ -109,12 +126,28 @@ public void setup() {
when(updateRequest.getConnectorId()).thenReturn(connector_id);
when(updateRequest.getUpdateContent()).thenReturn(updateContent);

SearchHits hits = new SearchHits(new SearchHit[] {}, new TotalHits(0, TotalHits.Relation.EQUAL_TO), Float.NaN);
SearchResponseSections searchSections = new SearchResponseSections(hits, InternalAggregations.EMPTY, null, false, false, null, 1);
searchResponse = new SearchResponse(
searchSections,
null,
1,
1,
0,
11,
ShardSearchFailure.EMPTY_ARRAY,
SearchResponse.Clusters.EMPTY
);

transportUpdateConnectorAction = new UpdateConnectorTransportAction(
transportService,
actionFilters,
client,
connectorAccessControlHelper
connectorAccessControlHelper,
mlModelManager
);

when(mlModelManager.getAllModelIds()).thenReturn(new String[] {});
shardId = new ShardId(new Index("indexName", "uuid"), 1);
updateResponse = new UpdateResponse(shardId, "taskId", 1, 1, 1, DocWriteResponse.Result.UPDATED);
}
Expand All @@ -126,6 +159,12 @@ public void test_execute_connectorAccessControl_success() {
return null;
}).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(String.class), isA(ActionListener.class));

doAnswer(invocation -> {
ActionListener<SearchResponse> actionListener = invocation.getArgument(1);
actionListener.onResponse(searchResponse);
return null;
}).when(client).search(any(SearchRequest.class), isA(ActionListener.class));

doAnswer(invocation -> {
ActionListener<UpdateResponse> listener = invocation.getArgument(1);
listener.onResponse(updateResponse);
Expand Down Expand Up @@ -182,6 +221,13 @@ public void test_execute_UpdateWrongStatus() {
listener.onResponse(true);
return null;
}).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(String.class), isA(ActionListener.class));

doAnswer(invocation -> {
ActionListener<SearchResponse> actionListener = invocation.getArgument(1);
actionListener.onResponse(searchResponse);
return null;
}).when(client).search(any(SearchRequest.class), isA(ActionListener.class));

UpdateResponse updateResponse = new UpdateResponse(shardId, "taskId", 1, 1, 1, DocWriteResponse.Result.CREATED);
doAnswer(invocation -> {
ActionListener<UpdateResponse> listener = invocation.getArgument(1);
Expand All @@ -200,6 +246,12 @@ public void test_execute_UpdateException() {
return null;
}).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(String.class), isA(ActionListener.class));

doAnswer(invocation -> {
ActionListener<SearchResponse> actionListener = invocation.getArgument(1);
actionListener.onResponse(searchResponse);
return null;
}).when(client).search(any(SearchRequest.class), isA(ActionListener.class));

doAnswer(invocation -> {
ActionListener<UpdateResponse> listener = invocation.getArgument(1);
listener.onFailure(new RuntimeException("update document failure"));
Expand All @@ -211,4 +263,84 @@ public void test_execute_UpdateException() {
verify(actionListener).onFailure(argumentCaptor.capture());
assertEquals("update document failure", argumentCaptor.getValue().getMessage());
}

public void test_execute_SearchResponseNotEmpty() {
doAnswer(invocation -> {
ActionListener<Boolean> listener = invocation.getArgument(2);
listener.onResponse(true);
return null;
}).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(String.class), isA(ActionListener.class));

doAnswer(invocation -> {
ActionListener<SearchResponse> actionListener = invocation.getArgument(1);
actionListener.onResponse(noneEmptySearchResponse());
return null;
}).when(client).search(any(SearchRequest.class), isA(ActionListener.class));

transportUpdateConnectorAction.doExecute(task, updateRequest, actionListener);
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
verify(actionListener).onFailure(argumentCaptor.capture());
assertEquals("1 models are still using this connector, please undeploy the models first!", argumentCaptor.getValue().getMessage());
}

public void test_execute_SearchResponseError() {
doAnswer(invocation -> {
ActionListener<Boolean> listener = invocation.getArgument(2);
listener.onResponse(true);
return null;
}).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(String.class), isA(ActionListener.class));

doAnswer(invocation -> {
ActionListener<SearchResponse> actionListener = invocation.getArgument(1);
actionListener.onFailure(new RuntimeException("Error in Search Request"));
return null;
}).when(client).search(any(SearchRequest.class), isA(ActionListener.class));

transportUpdateConnectorAction.doExecute(task, updateRequest, actionListener);
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(RuntimeException.class);
verify(actionListener).onFailure(argumentCaptor.capture());
assertEquals("Error in Search Request", argumentCaptor.getValue().getMessage());
}

public void test_execute_SearchIndexNotFoundError() {
doAnswer(invocation -> {
ActionListener<Boolean> listener = invocation.getArgument(2);
listener.onResponse(true);
return null;
}).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(String.class), isA(ActionListener.class));

doAnswer(invocation -> {
ActionListener<SearchResponse> actionListener = invocation.getArgument(1);
actionListener.onFailure(new IndexNotFoundException("Index not found!"));
return null;
}).when(client).search(any(SearchRequest.class), isA(ActionListener.class));

doAnswer(invocation -> {
ActionListener<UpdateResponse> listener = invocation.getArgument(1);
listener.onResponse(updateResponse);
return null;
}).when(client).update(any(UpdateRequest.class), isA(ActionListener.class));

transportUpdateConnectorAction.doExecute(task, updateRequest, actionListener);
verify(actionListener).onResponse(updateResponse);
}

private SearchResponse noneEmptySearchResponse() throws IOException {
String modelContent = "{\"name\":\"Remote_Model\",\"algorithm\":\"Remote\",\"version\":1,\"connector_id\":\"test_id\"}";
SearchHit model = SearchHit.fromXContent(TestHelper.parser(modelContent));
SearchHits hits = new SearchHits(new SearchHit[] { model }, new TotalHits(1, TotalHits.Relation.EQUAL_TO), Float.NaN);
SearchResponseSections searchSections = new SearchResponseSections(hits, InternalAggregations.EMPTY, null, false, false, null, 1);
SearchResponse searchResponse = new SearchResponse(
searchSections,
null,
1,
1,
0,
11,
ShardSearchFailure.EMPTY_ARRAY,
SearchResponse.Clusters.EMPTY
);

return searchResponse;
}
}

0 comments on commit d73e3e5

Please sign in to comment.