diff --git a/common/src/main/java/org/opensearch/ml/common/connector/Connector.java b/common/src/main/java/org/opensearch/ml/common/connector/Connector.java index aa2f93235e..86068ad0f9 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/Connector.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/Connector.java @@ -184,4 +184,6 @@ default void validateConnectorURL(List urlRegexes) { } Map getDecryptedHeaders(); + + Map getDecryptedCredential(); } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/batch/MLBatchIngestionInput.java b/common/src/main/java/org/opensearch/ml/common/transport/batch/MLBatchIngestionInput.java index fbdc895efc..a6a263f178 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/batch/MLBatchIngestionInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/batch/MLBatchIngestionInput.java @@ -20,10 +20,12 @@ import lombok.Builder; import lombok.Getter; +import lombok.Setter; /** * ML batch ingestion data: index, field mapping and input and out files. */ +@Getter public class MLBatchIngestionInput implements ToXContentObject, Writeable { public static final String INDEX_NAME_FIELD = "index_name"; @@ -31,17 +33,15 @@ public class MLBatchIngestionInput implements ToXContentObject, Writeable { public static final String INGEST_FIELDS = "ingest_fields"; public static final String CONNECTOR_CREDENTIAL_FIELD = "credential"; public static final String DATA_SOURCE_FIELD = "data_source"; + public static final String CONNECTOR_ID_FIELD = "connector_id"; - @Getter private String indexName; - @Getter private Map fieldMapping; - @Getter private String[] ingestFields; - @Getter private Map dataSources; - @Getter + @Setter private Map credential; + private String connectorId; @Builder(toBuilder = true) public MLBatchIngestionInput( @@ -49,7 +49,8 @@ public MLBatchIngestionInput( Map fieldMapping, String[] ingestFields, Map dataSources, - Map credential + Map credential, + String connectorId ) { if (indexName == null) { throw new IllegalArgumentException( @@ -66,6 +67,7 @@ public MLBatchIngestionInput( this.ingestFields = ingestFields; this.dataSources = dataSources; this.credential = credential; + this.connectorId = connectorId; } public static MLBatchIngestionInput parse(XContentParser parser) throws IOException { @@ -74,6 +76,7 @@ public static MLBatchIngestionInput parse(XContentParser parser) throws IOExcept String[] ingestFields = null; Map dataSources = null; Map credential = new HashMap<>(); + String connectorId = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -93,6 +96,9 @@ public static MLBatchIngestionInput parse(XContentParser parser) throws IOExcept case CONNECTOR_CREDENTIAL_FIELD: credential = parser.mapStrings(); break; + case CONNECTOR_ID_FIELD: + connectorId = parser.text(); + break; case DATA_SOURCE_FIELD: dataSources = parser.map(); break; @@ -101,7 +107,7 @@ public static MLBatchIngestionInput parse(XContentParser parser) throws IOExcept break; } } - return new MLBatchIngestionInput(indexName, fieldMapping, ingestFields, dataSources, credential); + return new MLBatchIngestionInput(indexName, fieldMapping, ingestFields, dataSources, credential, connectorId); } @Override @@ -119,6 +125,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (credential != null) { builder.field(CONNECTOR_CREDENTIAL_FIELD, credential); } + if (connectorId != null) { + builder.field(CONNECTOR_ID_FIELD, connectorId); + } if (dataSources != null) { builder.field(DATA_SOURCE_FIELD, dataSources); } @@ -147,6 +156,7 @@ public void writeTo(StreamOutput output) throws IOException { } else { output.writeBoolean(false); } + output.writeOptionalString(connectorId); if (dataSources != null) { output.writeBoolean(true); output.writeMap(dataSources, StreamOutput::writeString, StreamOutput::writeGenericValue); @@ -166,6 +176,7 @@ public MLBatchIngestionInput(StreamInput input) throws IOException { if (input.readBoolean()) { credential = input.readMap(s -> s.readString(), s -> s.readString()); } + this.connectorId = input.readOptionalString(); if (input.readBoolean()) { dataSources = input.readMap(s -> s.readString(), s -> s.readGenericValue()); } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java index b10ad22a69..ecbc05c43e 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java @@ -5,6 +5,9 @@ package org.opensearch.ml.engine; +import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.PREDICT; +import static org.opensearch.ml.common.connector.HttpConnector.REGION_FIELD; + import java.nio.file.Path; import java.util.Locale; import java.util.Map; @@ -12,6 +15,7 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLModel; +import org.opensearch.ml.common.connector.Connector; import org.opensearch.ml.common.dataframe.DataFrame; import org.opensearch.ml.common.dataset.DataFrameInputDataset; import org.opensearch.ml.common.dataset.MLInputDataset; @@ -120,6 +124,16 @@ public MLModel train(Input input) { return trainable.train(mlInput); } + public Map getConnectorCredential(Connector connector) { + connector.decrypt(PREDICT.name(), (credential) -> encryptor.decrypt(credential)); + Map decryptedCredential = connector.getDecryptedCredential(); + String region = connector.getParameters().get(REGION_FIELD); + if (region != null) { + decryptedCredential.putIfAbsent(REGION_FIELD, region); + } + return decryptedCredential; + } + public Predictable deploy(MLModel mlModel, Map params) { Predictable predictable = MLEngineClassLoader.initInstance(mlModel.getAlgorithm(), null, MLAlgoParams.class); predictable.initModel(mlModel, params, encryptor); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineTest.java index cf5ccd2152..95c3d5d218 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineTest.java @@ -16,6 +16,8 @@ import java.io.IOException; import java.nio.file.Path; import java.util.Arrays; +import java.util.Collections; +import java.util.Map; import java.util.UUID; import org.junit.Assert; @@ -24,11 +26,16 @@ import org.junit.Test; import org.junit.rules.ExpectedException; import org.mockito.MockedStatic; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLModel; +import org.opensearch.ml.common.connector.HttpConnector; import org.opensearch.ml.common.dataframe.ColumnMeta; import org.opensearch.ml.common.dataframe.DataFrame; import org.opensearch.ml.common.dataframe.DefaultDataFrame; @@ -47,6 +54,7 @@ import org.opensearch.ml.engine.algorithms.regression.LinearRegression; import org.opensearch.ml.engine.encryptor.Encryptor; import org.opensearch.ml.engine.encryptor.EncryptorImpl; +import org.opensearch.search.SearchModule; // TODO: refactor MLEngineClassLoader's static functions to avoid mockStatic public class MLEngineTest extends MLStaticMockBase { @@ -408,4 +416,32 @@ public void testEncryptMethod() { assertNotEquals(testString, encryptedString); } + @Test + public void testGetConnectorCredential() throws IOException { + String encryptedValue = mlEngine.encrypt("test_key_value"); + String test_connector_string = "{\"name\":\"test_connector_name\",\"version\":\"1\"," + + "\"description\":\"this is a test connector\",\"protocol\":\"http\"," + + "\"parameters\":{\"region\":\"test region\"},\"credential\":{\"key\":\"" + + encryptedValue + + "\"}," + + "\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":\"https://test.com\"," + + "\"headers\":{\"api_key\":\"${credential.key}\"}," + + "\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\"}]," + + "\"retry_backoff_millis\":10,\"retry_timeout_seconds\":10,\"max_retry_times\":-1,\"retry_backoff_policy\":\"constant\"}}"; + + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), + null, + test_connector_string + ); + parser.nextToken(); + + HttpConnector connector = new HttpConnector("http", parser); + Map decryptedCredential = mlEngine.getConnectorCredential(connector); + assertNotNull(decryptedCredential); + assertEquals(decryptedCredential.get("key"), "test_key_value"); + assertEquals(decryptedCredential.get("region"), "test region"); + } } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/ingest/AbstractIngestionTests.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/ingest/AbstractIngestionTests.java index 1f1653b31c..3bb23e3486 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/ingest/AbstractIngestionTests.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/ingest/AbstractIngestionTests.java @@ -166,7 +166,8 @@ public void testFilterFieldMapping_ValidInput_EmptyPrefix() { fieldMap, ingestFields, new HashMap<>(), - new HashMap<>() + new HashMap<>(), + null ); Map result = s3DataIngestion.filterFieldMapping(mlBatchIngestionInput, 0); @@ -190,7 +191,8 @@ public void testFilterFieldMapping_MatchingPrefix() { fieldMap, ingestFields, new HashMap<>(), - new HashMap<>() + new HashMap<>(), + null ); // Act @@ -219,7 +221,8 @@ public void testFilterFieldMappingSoleSource_MatchingPrefix() { fieldMap, ingestFields, new HashMap<>(), - new HashMap<>() + new HashMap<>(), + null ); // Act @@ -292,7 +295,8 @@ public void testBatchIngestSuccess_SoleSource() { fieldMap, ingestFields, new HashMap<>(), - new HashMap<>() + new HashMap<>(), + null ); ActionListener bulkResponseListener = mock(ActionListener.class); s3DataIngestion.batchIngest(sourceLines, mlBatchIngestionInput, bulkResponseListener, 0, true); @@ -318,7 +322,8 @@ public void testBatchIngestSuccess_returnForNullJasonMap() { fieldMap, ingestFields, new HashMap<>(), - new HashMap<>() + new HashMap<>(), + null ); ActionListener bulkResponseListener = mock(ActionListener.class); s3DataIngestion.batchIngest(sourceLines, mlBatchIngestionInput, bulkResponseListener, 0, false); diff --git a/plugin/src/main/java/org/opensearch/ml/action/batch/TransportBatchIngestionAction.java b/plugin/src/main/java/org/opensearch/ml/action/batch/TransportBatchIngestionAction.java index 870e8ce97f..aa78c119f9 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/batch/TransportBatchIngestionAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/batch/TransportBatchIngestionAction.java @@ -15,7 +15,6 @@ import java.time.Instant; import java.util.List; -import java.util.Locale; import java.util.Map; import java.util.regex.Pattern; import java.util.stream.Collectors; @@ -37,6 +36,7 @@ import org.opensearch.ml.common.transport.batch.MLBatchIngestionResponse; import org.opensearch.ml.engine.MLEngineClassLoader; import org.opensearch.ml.engine.ingest.Ingestable; +import org.opensearch.ml.model.MLModelManager; import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.ml.task.MLTaskManager; import org.opensearch.ml.utils.MLExceptionUtils; @@ -56,6 +56,7 @@ public class TransportBatchIngestionAction extends HandledTransportAction { - String taskId = response.getId(); - try { - mlTask.setTaskId(taskId); - mlTaskManager.add(mlTask); - listener.onResponse(new MLBatchIngestionResponse(taskId, MLTaskType.BATCH_INGEST, MLTaskState.CREATED.name())); - String ingestType = (String) mlBatchIngestionInput.getDataSources().get(TYPE); - Ingestable ingestable = MLEngineClassLoader.initInstance(ingestType.toLowerCase(Locale.ROOT), client, Client.class); - threadPool.executor(INGEST_THREAD_POOL).execute(() -> { - executeWithErrorHandling(() -> { - double successRate = ingestable.ingest(mlBatchIngestionInput); - handleSuccessRate(successRate, taskId); - }, taskId); - }); - } catch (Exception ex) { - log.error("Failed in batch ingestion", ex); - mlTaskManager - .updateMLTask( - taskId, - Map.of(STATE_FIELD, FAILED, ERROR_FIELD, MLExceptionUtils.getRootCauseMessage(ex)), - TASK_SEMAPHORE_TIMEOUT, - true + if (mlBatchIngestionInput.getConnectorId() != null + && (mlBatchIngestionInput.getCredential() == null || mlBatchIngestionInput.getCredential().isEmpty())) { + mlModelManager.getConnectorCredential(mlBatchIngestionInput.getConnectorId(), ActionListener.wrap(credentialMap -> { + mlBatchIngestionInput.setCredential(credentialMap); + createMLTaskandExecute(mlBatchIngestionInput, listener); + }, e -> { + log.error(e.getMessage()); + listener + .onFailure( + new OpenSearchStatusException( + "Fail to fetch credentials from the connector in the batch ingestion input: " + e.getMessage(), + RestStatus.BAD_REQUEST + ) ); - listener.onFailure(ex); - } - }, exception -> { - log.error("Failed to create batch ingestion task", exception); - listener.onFailure(exception); - })); + })); + } else { + createMLTaskandExecute(mlBatchIngestionInput, listener); + } } catch (IllegalArgumentException e) { log.error(e.getMessage()); listener @@ -138,6 +121,47 @@ protected void doExecute(Task task, ActionRequest request, ActionListener listener) { + MLTask mlTask = MLTask + .builder() + .async(true) + .taskType(MLTaskType.BATCH_INGEST) + .createTime(Instant.now()) + .lastUpdateTime(Instant.now()) + .state(MLTaskState.CREATED) + .build(); + + mlTaskManager.createMLTask(mlTask, ActionListener.wrap(response -> { + String taskId = response.getId(); + try { + mlTask.setTaskId(taskId); + mlTaskManager.add(mlTask); + listener.onResponse(new MLBatchIngestionResponse(taskId, MLTaskType.BATCH_INGEST, MLTaskState.CREATED.name())); + String ingestType = (String) mlBatchIngestionInput.getDataSources().get(TYPE); + Ingestable ingestable = MLEngineClassLoader.initInstance(ingestType.toLowerCase(), client, Client.class); + threadPool.executor(INGEST_THREAD_POOL).execute(() -> { + executeWithErrorHandling(() -> { + double successRate = ingestable.ingest(mlBatchIngestionInput); + handleSuccessRate(successRate, taskId); + }, taskId); + }); + } catch (Exception ex) { + log.error("Failed in batch ingestion", ex); + mlTaskManager + .updateMLTask( + taskId, + Map.of(STATE_FIELD, FAILED, ERROR_FIELD, MLExceptionUtils.getRootCauseMessage(ex)), + TASK_SEMAPHORE_TIMEOUT, + true + ); + listener.onFailure(ex); + } + }, exception -> { + log.error("Failed to create batch ingestion task", exception); + listener.onFailure(exception); + })); + } + protected void executeWithErrorHandling(Runnable task, String taskId) { try { task.run(); @@ -190,6 +214,9 @@ private void validateBatchIngestInput(MLBatchIngestionInput mlBatchIngestionInpu || mlBatchIngestionInput.getDataSources().isEmpty()) { throw new IllegalArgumentException("The batch ingest input data source cannot be null"); } + if (mlBatchIngestionInput.getCredential() == null && mlBatchIngestionInput.getConnectorId() == null) { + throw new IllegalArgumentException("The batch ingest credential or connector_id cannot be null"); + } Map dataSources = mlBatchIngestionInput.getDataSources(); if (dataSources.get(TYPE) == null || dataSources.get(SOURCE) == null) { throw new IllegalArgumentException("The batch ingest input data source is missing data type or source"); diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java index 72508c25c0..a85d173859 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java @@ -965,6 +965,14 @@ private void handleException(FunctionName functionName, String taskId, Exception mlTaskManager.updateMLTask(taskId, updated, TIMEOUT_IN_MILLIS, true); } + public void getConnectorCredential(String connectorId, ActionListener> connectorCredentialListener) { + getConnector(connectorId, ActionListener.wrap(connector -> { + Map credential = mlEngine.getConnectorCredential(connector); + connectorCredentialListener.onResponse(credential); + log.info("Completed loading credential in the connector {}", connectorId); + }, connectorCredentialListener::onFailure)); + } + /** * Read model chunks from model index. Concat chunks into a whole model file, * then load diff --git a/plugin/src/test/java/org/opensearch/ml/action/batch/TransportBatchIngestionActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/batch/TransportBatchIngestionActionTests.java index f1ab6715f6..3ad8ba2d07 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/batch/TransportBatchIngestionActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/batch/TransportBatchIngestionActionTests.java @@ -46,6 +46,7 @@ import org.opensearch.ml.common.transport.batch.MLBatchIngestionInput; import org.opensearch.ml.common.transport.batch.MLBatchIngestionRequest; import org.opensearch.ml.common.transport.batch.MLBatchIngestionResponse; +import org.opensearch.ml.model.MLModelManager; import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.ml.task.MLTaskManager; import org.opensearch.tasks.Task; @@ -63,6 +64,8 @@ public class TransportBatchIngestionActionTests extends OpenSearchTestCase { @Mock private MLTaskManager mlTaskManager; @Mock + MLModelManager mlModelManager; + @Mock private ActionFilters actionFilters; @Mock private MLBatchIngestionRequest mlBatchIngestionRequest; @@ -79,7 +82,9 @@ public class TransportBatchIngestionActionTests extends OpenSearchTestCase { private TransportBatchIngestionAction batchAction; private MLBatchIngestionInput batchInput; + private MLBatchIngestionInput mlBatchIngestionInputWithConnector; private String[] ingestFields; + private Map credential; @Before public void setup() { @@ -90,6 +95,7 @@ public void setup() { client, mlTaskManager, threadPool, + mlModelManager, mlFeatureEnabledSetting ); @@ -101,7 +107,7 @@ public void setup() { ingestFields = new String[] { "$.id" }; - Map credential = Map + credential = Map .of("region", "us-east-1", "access_key", "some accesskey", "secret_key", "some secret", "session_token", "some token"); Map dataSource = new HashMap<>(); dataSource.put("type", "s3"); @@ -117,6 +123,15 @@ public void setup() { .build(); when(mlBatchIngestionRequest.getMlBatchIngestionInput()).thenReturn(batchInput); + mlBatchIngestionInputWithConnector = MLBatchIngestionInput + .builder() + .indexName("testIndex") + .fieldMapping(fieldMap) + .ingestFields(ingestFields) + .connectorId("test_connector_id") + .dataSources(dataSource) + .build(); + when(mlFeatureEnabledSetting.isOfflineBatchIngestionEnabled()).thenReturn(true); } @@ -319,4 +334,33 @@ public void test_doExecute_batchIngestionFailed() { assertEquals("some error", argumentCaptor.getValue().getMessage()); verify(mlTaskManager).updateMLTask("taskId", Map.of(STATE_FIELD, FAILED, ERROR_FIELD, "some error"), TASK_SEMAPHORE_TIMEOUT, true); } + + public void test_doExecute_withConnector_success() { + when(mlBatchIngestionRequest.getMlBatchIngestionInput()).thenReturn(mlBatchIngestionInputWithConnector); + + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onResponse(credential); + return null; + }).when(mlModelManager).getConnectorCredential(anyString(), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + ShardId shardId = new ShardId(new Index("indexName", "uuid"), 1); + IndexResponse indexResponse = new IndexResponse(shardId, "taskId", 1, 1, 1, true); + listener.onResponse(indexResponse); + return null; + }).when(mlTaskManager).createMLTask(isA(MLTask.class), isA(ActionListener.class)); + doReturn(executorService).when(threadPool).executor(INGEST_THREAD_POOL); + doAnswer(invocation -> { + Runnable runnable = invocation.getArgument(0); + runnable.run(); + return null; + }).when(executorService).execute(any(Runnable.class)); + + batchAction.doExecute(task, mlBatchIngestionRequest, actionListener); + + verify(actionListener).onResponse(any(MLBatchIngestionResponse.class)); + verify(threadPool).executor(INGEST_THREAD_POOL); + } }