use connector credential in offline batch ingestion
Signed-off-by: Xun Zhang <[email protected]>
Zhangxunmt committed Oct 10, 2024
1 parent d7e0fe4 commit e6ff618
Showing 6 changed files with 113 additions and 44 deletions.
@@ -179,4 +179,6 @@ default void validateConnectorURL(List<String> urlRegexes) {
}

Map<String, String> getDecryptedHeaders();

Map<String, String> getDecryptedCredential();
}
@@ -20,36 +20,37 @@

import lombok.Builder;
import lombok.Getter;
import lombok.Setter;

/**
* ML batch ingestion data: index, field mapping, and input and output files.
*/
@Getter
public class MLBatchIngestionInput implements ToXContentObject, Writeable {

public static final String INDEX_NAME_FIELD = "index_name";
public static final String FIELD_MAP_FIELD = "field_map";
public static final String INGEST_FIELDS = "ingest_fields";
public static final String CONNECTOR_CREDENTIAL_FIELD = "credential";
public static final String DATA_SOURCE_FIELD = "data_source";
public static final String CONNECTOR_ID_FIELD = "connector_id";

@Getter
private String indexName;
@Getter
private Map<String, Object> fieldMapping;
@Getter
private String[] ingestFields;
@Getter
private Map<String, Object> dataSources;
@Getter
@Setter
private Map<String, String> credential;
private String connectorId;

@Builder(toBuilder = true)
public MLBatchIngestionInput(
String indexName,
Map<String, Object> fieldMapping,
String[] ingestFields,
Map<String, Object> dataSources,
Map<String, String> credential
Map<String, String> credential,
String connectorId
) {
if (indexName == null) {
throw new IllegalArgumentException(
@@ -66,6 +67,7 @@ public MLBatchIngestionInput(
this.ingestFields = ingestFields;
this.dataSources = dataSources;
this.credential = credential;
this.connectorId = connectorId;
}

public static MLBatchIngestionInput parse(XContentParser parser) throws IOException {
@@ -74,6 +76,7 @@ public static MLBatchIngestionInput parse(XContentParser parser) throws IOExcept
String[] ingestFields = null;
Map<String, Object> dataSources = null;
Map<String, String> credential = new HashMap<>();
String connectorId = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
@@ -93,6 +96,9 @@ public static MLBatchIngestionInput parse(XContentParser parser) throws IOExcept
case CONNECTOR_CREDENTIAL_FIELD:
credential = parser.mapStrings();
break;
case CONNECTOR_ID_FIELD:
connectorId = parser.text();
break;
case DATA_SOURCE_FIELD:
dataSources = parser.map();
break;
@@ -101,7 +107,7 @@ public static MLBatchIngestionInput parse(XContentParser parser) throws IOExcept
break;
}
}
return new MLBatchIngestionInput(indexName, fieldMapping, ingestFields, dataSources, credential);
return new MLBatchIngestionInput(indexName, fieldMapping, ingestFields, dataSources, credential, connectorId);
}

@Override
@@ -119,6 +125,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
if (credential != null) {
builder.field(CONNECTOR_CREDENTIAL_FIELD, credential);
}
if (connectorId != null) {
builder.field(CONNECTOR_ID_FIELD, connectorId);
}
if (dataSources != null) {
builder.field(DATA_SOURCE_FIELD, dataSources);
}
@@ -147,6 +156,7 @@ public void writeTo(StreamOutput output) throws IOException {
} else {
output.writeBoolean(false);
}
output.writeOptionalString(connectorId);
if (dataSources != null) {
output.writeBoolean(true);
output.writeMap(dataSources, StreamOutput::writeString, StreamOutput::writeGenericValue);
@@ -166,6 +176,7 @@ public MLBatchIngestionInput(StreamInput input) throws IOException {
if (input.readBoolean()) {
credential = input.readMap(s -> s.readString(), s -> s.readString());
}
this.connectorId = input.readOptionalString();
if (input.readBoolean()) {
dataSources = input.readMap(s -> s.readString(), s -> s.readGenericValue());
}
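For context, a minimal usage sketch (not part of this commit; the index name, field mapping, data source keys and values, and connector id below are all hypothetical) showing how the new connector_id field can be passed through the existing builder instead of an inline credential map:

import java.util.Map;

import org.opensearch.ml.common.transport.batch.MLBatchIngestionInput;

public class BatchIngestionInputExample {
    public static MLBatchIngestionInput buildWithConnectorId() {
        // Assumed field mapping and data source entries, for illustration only.
        Map<String, Object> fieldMapping = Map.of("chapter_embedding", "$.response.embedding");
        Map<String, Object> dataSources = Map.of("type", "s3", "source", "s3://my-bucket/batch-output.jsonl");
        // Only connector_id is set, so the transport action is expected to resolve
        // the decrypted credential from the connector before ingestion starts.
        return MLBatchIngestionInput
            .builder()
            .indexName("my-ingest-index")
            .fieldMapping(fieldMapping)
            .dataSources(dataSources)
            .connectorId("my-connector-id")
            .build();
    }
}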
15 changes: 15 additions & 0 deletions ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java
@@ -5,13 +5,17 @@

package org.opensearch.ml.engine;

import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.PREDICT;
import static org.opensearch.ml.common.connector.HttpConnector.REGION_FIELD;

import java.nio.file.Path;
import java.util.Locale;
import java.util.Map;

import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.connector.Connector;
import org.opensearch.ml.common.dataframe.DataFrame;
import org.opensearch.ml.common.dataset.DataFrameInputDataset;
import org.opensearch.ml.common.dataset.MLInputDataset;
@@ -120,6 +124,17 @@ public MLModel train(Input input) {
return trainable.train(mlInput);
}

public Map<String, String> getConnectorCredential(Connector connector) {
connector.decrypt(PREDICT.name(), (credential) -> encryptor.decrypt(credential));
Map<String, String> decryptedCredential = connector.getDecryptedCredential();
String region = connector.getParameters().get(REGION_FIELD);
if (region != null) {
decryptedCredential.putIfAbsent(REGION_FIELD, region);
}

return decryptedCredential;
}

public Predictable deploy(MLModel mlModel, Map<String, Object> params) {
Predictable predictable = MLEngineClassLoader.initInstance(mlModel.getAlgorithm(), null, MLAlgoParams.class);
predictable.initModel(mlModel, params, encryptor);
@@ -36,6 +36,7 @@
import org.opensearch.ml.common.transport.batch.MLBatchIngestionResponse;
import org.opensearch.ml.engine.MLEngineClassLoader;
import org.opensearch.ml.engine.ingest.Ingestable;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
import org.opensearch.ml.task.MLTaskManager;
import org.opensearch.ml.utils.MLExceptionUtils;
@@ -55,6 +56,7 @@ public class TransportBatchIngestionAction extends HandledTransportAction<Action
public static final String SOURCE = "source";
TransportService transportService;
MLTaskManager mlTaskManager;
MLModelManager mlModelManager;
private final Client client;
private ThreadPool threadPool;
private MLFeatureEnabledSetting mlFeatureEnabledSetting;
@@ -66,13 +68,15 @@ public TransportBatchIngestionAction(
Client client,
MLTaskManager mlTaskManager,
ThreadPool threadPool,
MLModelManager mlModelManager,
MLFeatureEnabledSetting mlFeatureEnabledSetting
) {
super(MLBatchIngestionAction.NAME, transportService, actionFilters, MLBatchIngestionRequest::new);
this.transportService = transportService;
this.client = client;
this.mlTaskManager = mlTaskManager;
this.threadPool = threadPool;
this.mlModelManager = mlModelManager;
this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
}

@@ -85,44 +89,25 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLBatc
throw new IllegalStateException(OFFLINE_BATCH_INGESTION_DISABLED_ERR_MSG);
}
validateBatchIngestInput(mlBatchIngestionInput);
MLTask mlTask = MLTask
.builder()
.async(true)
.taskType(MLTaskType.BATCH_INGEST)
.createTime(Instant.now())
.lastUpdateTime(Instant.now())
.state(MLTaskState.CREATED)
.build();

mlTaskManager.createMLTask(mlTask, ActionListener.wrap(response -> {
String taskId = response.getId();
try {
mlTask.setTaskId(taskId);
mlTaskManager.add(mlTask);
listener.onResponse(new MLBatchIngestionResponse(taskId, MLTaskType.BATCH_INGEST, MLTaskState.CREATED.name()));
String ingestType = (String) mlBatchIngestionInput.getDataSources().get(TYPE);
Ingestable ingestable = MLEngineClassLoader.initInstance(ingestType.toLowerCase(), client, Client.class);
threadPool.executor(INGEST_THREAD_POOL).execute(() -> {
executeWithErrorHandling(() -> {
double successRate = ingestable.ingest(mlBatchIngestionInput);
handleSuccessRate(successRate, taskId);
}, taskId);
});
} catch (Exception ex) {
log.error("Failed in batch ingestion", ex);
mlTaskManager
.updateMLTask(
taskId,
Map.of(STATE_FIELD, FAILED, ERROR_FIELD, MLExceptionUtils.getRootCauseMessage(ex)),
TASK_SEMAPHORE_TIMEOUT,
true

if (mlBatchIngestionInput.getConnectorId() != null && (mlBatchIngestionInput.getCredential() == null || mlBatchIngestionInput.getCredential().isEmpty())) {
mlModelManager.getConnectorCredential(mlBatchIngestionInput.getConnectorId(), ActionListener.wrap(credentialMap -> {

mlBatchIngestionInput.setCredential(credentialMap);
createMLTaskandExecute(mlBatchIngestionInput, listener);
}, e -> {
log.error(e.getMessage());
listener
.onFailure(
new OpenSearchStatusException(
"Fail to fetch credentials from the connector in the batch ingestion input: " + e.getMessage(),
RestStatus.BAD_REQUEST
)
);
listener.onFailure(ex);
}
}, exception -> {
log.error("Failed to create batch ingestion task", exception);
listener.onFailure(exception);
}));
}));
} else {
createMLTaskandExecute(mlBatchIngestionInput, listener);
}
} catch (IllegalArgumentException e) {
log.error(e.getMessage());
listener
@@ -137,6 +122,47 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLBatc
}
}

protected void createMLTaskandExecute(MLBatchIngestionInput mlBatchIngestionInput, ActionListener<MLBatchIngestionResponse> listener) {
MLTask mlTask = MLTask
.builder()
.async(true)
.taskType(MLTaskType.BATCH_INGEST)
.createTime(Instant.now())
.lastUpdateTime(Instant.now())
.state(MLTaskState.CREATED)
.build();

mlTaskManager.createMLTask(mlTask, ActionListener.wrap(response -> {
String taskId = response.getId();
try {
mlTask.setTaskId(taskId);
mlTaskManager.add(mlTask);
listener.onResponse(new MLBatchIngestionResponse(taskId, MLTaskType.BATCH_INGEST, MLTaskState.CREATED.name()));
String ingestType = (String) mlBatchIngestionInput.getDataSources().get(TYPE);
Ingestable ingestable = MLEngineClassLoader.initInstance(ingestType.toLowerCase(), client, Client.class);
threadPool.executor(INGEST_THREAD_POOL).execute(() -> {
executeWithErrorHandling(() -> {
double successRate = ingestable.ingest(mlBatchIngestionInput);
handleSuccessRate(successRate, taskId);
}, taskId);
});
} catch (Exception ex) {
log.error("Failed in batch ingestion", ex);
mlTaskManager
.updateMLTask(
taskId,
Map.of(STATE_FIELD, FAILED, ERROR_FIELD, MLExceptionUtils.getRootCauseMessage(ex)),
TASK_SEMAPHORE_TIMEOUT,
true
);
listener.onFailure(ex);
}
}, exception -> {
log.error("Failed to create batch ingestion task", exception);
listener.onFailure(exception);
}));
}

protected void executeWithErrorHandling(Runnable task, String taskId) {
try {
task.run();
@@ -189,6 +215,9 @@ private void validateBatchIngestInput(MLBatchIngestionInpu
|| mlBatchIngestionInput.getDataSources().isEmpty()) {
throw new IllegalArgumentException("The batch ingest input data source cannot be null");
}
if (mlBatchIngestionInput.getCredential() == null && mlBatchIngestionInput.getConnectorId() == null) {
throw new IllegalArgumentException("The batch ingest credential or connector_id cannot be null");
}
Map<String, Object> dataSources = mlBatchIngestionInput.getDataSources();
if (dataSources.get(TYPE) == null || dataSources.get(SOURCE) == null) {
throw new IllegalArgumentException("The batch ingest input data source is missing data type or source");
@@ -965,6 +965,14 @@ private void handleException(FunctionName functionName, String taskId, Exception
mlTaskManager.updateMLTask(taskId, updated, TIMEOUT_IN_MILLIS, true);
}

public void getConnectorCredential(String connectorId, ActionListener<Map<String, String>> connectorCredentialListener) {
getConnector(connectorId, ActionListener.wrap(connector -> {
Map<String, String> credential = mlEngine.getConnectorCredential(connector);
connectorCredentialListener.onResponse(credential);
log.info("Completed loading credential in the connector {}", connectorId);
}, connectorCredentialListener::onFailure));
}
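As a rough illustration of the new helper (the wrapper class below is hypothetical; only getConnectorCredential itself comes from this diff), a caller resolves the decrypted credential asynchronously and forwards failures to the same listener, which is the pattern the transport action uses before creating the ML task:

import java.util.Map;

import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.model.MLModelManager;

public class ConnectorCredentialResolverExample {
    // Hypothetical helper: MLModelManager fetches the connector, decrypts its credential
    // via MLEngine, and hands the resulting map to the listener; failures such as a
    // missing connector propagate through onFailure.
    public static void resolve(MLModelManager mlModelManager, String connectorId,
            ActionListener<Map<String, String>> listener) {
        mlModelManager.getConnectorCredential(connectorId, ActionListener.wrap(listener::onResponse, listener::onFailure));
    }
}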

/**
* Read model chunks from model index. Concat chunks into a whole model file,
* then load
@@ -46,6 +46,7 @@
import org.opensearch.ml.common.transport.batch.MLBatchIngestionInput;
import org.opensearch.ml.common.transport.batch.MLBatchIngestionRequest;
import org.opensearch.ml.common.transport.batch.MLBatchIngestionResponse;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
import org.opensearch.ml.task.MLTaskManager;
import org.opensearch.tasks.Task;
@@ -63,6 +64,8 @@ public class TransportBatchIngestionActionTests extends OpenSearchTestCase {
@Mock
private MLTaskManager mlTaskManager;
@Mock
MLModelManager mlModelManager;
@Mock
private ActionFilters actionFilters;
@Mock
private MLBatchIngestionRequest mlBatchIngestionRequest;
@@ -90,6 +93,7 @@ public void setup() {
client,
mlTaskManager,
threadPool,
mlModelManager,
mlFeatureEnabledSetting
);

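A possible follow-up stub for the newly injected mock (not part of this commit; it assumes the usual Mockito static imports, and the matcher choices and credential contents are illustrative) could sit inside a test of the connector_id path and make the mocked MLModelManager answer with a decrypted credential map:

// Hypothetical stub inside a test method: when only connector_id is supplied in the
// request, the action calls mlModelManager.getConnectorCredential, so the mock answers
// with an assumed decrypted credential map before the ML task is created.
doAnswer(invocation -> {
    ActionListener<Map<String, String>> credentialListener = invocation.getArgument(1);
    credentialListener.onResponse(Map.of("region", "us-west-2"));
    return null;
}).when(mlModelManager).getConnectorCredential(anyString(), any());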
