Skip to content

Commit

Permalink
use connector credential in offline batch ingestion
Browse files Browse the repository at this point in the history
Signed-off-by: Xun Zhang <[email protected]>
  • Loading branch information
Zhangxunmt committed Oct 10, 2024
1 parent ff6fe67 commit 6cec85c
Show file tree
Hide file tree
Showing 8 changed files with 198 additions and 51 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -184,4 +184,6 @@ default void validateConnectorURL(List<String> urlRegexes) {
}

Map<String, String> getDecryptedHeaders();

Map<String, String> getDecryptedCredential();
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,36 +20,37 @@

import lombok.Builder;
import lombok.Getter;
import lombok.Setter;

/**
* ML batch ingestion data: index, field mapping and input and out files.
*/
@Getter
public class MLBatchIngestionInput implements ToXContentObject, Writeable {

public static final String INDEX_NAME_FIELD = "index_name";
public static final String FIELD_MAP_FIELD = "field_map";
public static final String INGEST_FIELDS = "ingest_fields";
public static final String CONNECTOR_CREDENTIAL_FIELD = "credential";
public static final String DATA_SOURCE_FIELD = "data_source";
public static final String CONNECTOR_ID_FIELD = "connector_id";

@Getter
private String indexName;
@Getter
private Map<String, Object> fieldMapping;
@Getter
private String[] ingestFields;
@Getter
private Map<String, Object> dataSources;
@Getter
@Setter
private Map<String, String> credential;
private String connectorId;

@Builder(toBuilder = true)
public MLBatchIngestionInput(
String indexName,
Map<String, Object> fieldMapping,
String[] ingestFields,
Map<String, Object> dataSources,
Map<String, String> credential
Map<String, String> credential,
String connectorId
) {
if (indexName == null) {
throw new IllegalArgumentException(
Expand All @@ -66,6 +67,7 @@ public MLBatchIngestionInput(
this.ingestFields = ingestFields;
this.dataSources = dataSources;
this.credential = credential;
this.connectorId = connectorId;
}

public static MLBatchIngestionInput parse(XContentParser parser) throws IOException {
Expand All @@ -74,6 +76,7 @@ public static MLBatchIngestionInput parse(XContentParser parser) throws IOExcept
String[] ingestFields = null;
Map<String, Object> dataSources = null;
Map<String, String> credential = new HashMap<>();
String connectorId = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
Expand All @@ -93,6 +96,9 @@ public static MLBatchIngestionInput parse(XContentParser parser) throws IOExcept
case CONNECTOR_CREDENTIAL_FIELD:
credential = parser.mapStrings();
break;
case CONNECTOR_ID_FIELD:
connectorId = parser.text();
break;
case DATA_SOURCE_FIELD:
dataSources = parser.map();
break;
Expand All @@ -101,7 +107,7 @@ public static MLBatchIngestionInput parse(XContentParser parser) throws IOExcept
break;
}
}
return new MLBatchIngestionInput(indexName, fieldMapping, ingestFields, dataSources, credential);
return new MLBatchIngestionInput(indexName, fieldMapping, ingestFields, dataSources, credential, connectorId);
}

@Override
Expand All @@ -119,6 +125,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
if (credential != null) {
builder.field(CONNECTOR_CREDENTIAL_FIELD, credential);
}
if (connectorId != null) {
builder.field(CONNECTOR_ID_FIELD, connectorId);
}
if (dataSources != null) {
builder.field(DATA_SOURCE_FIELD, dataSources);
}
Expand Down Expand Up @@ -147,6 +156,7 @@ public void writeTo(StreamOutput output) throws IOException {
} else {
output.writeBoolean(false);
}
output.writeOptionalString(connectorId);
if (dataSources != null) {
output.writeBoolean(true);
output.writeMap(dataSources, StreamOutput::writeString, StreamOutput::writeGenericValue);
Expand All @@ -166,6 +176,7 @@ public MLBatchIngestionInput(StreamInput input) throws IOException {
if (input.readBoolean()) {
credential = input.readMap(s -> s.readString(), s -> s.readString());
}
this.connectorId = input.readOptionalString();
if (input.readBoolean()) {
dataSources = input.readMap(s -> s.readString(), s -> s.readGenericValue());
}
Expand Down
14 changes: 14 additions & 0 deletions ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,17 @@

package org.opensearch.ml.engine;

import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.PREDICT;
import static org.opensearch.ml.common.connector.HttpConnector.REGION_FIELD;

import java.nio.file.Path;
import java.util.Locale;
import java.util.Map;

import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.connector.Connector;
import org.opensearch.ml.common.dataframe.DataFrame;
import org.opensearch.ml.common.dataset.DataFrameInputDataset;
import org.opensearch.ml.common.dataset.MLInputDataset;
Expand Down Expand Up @@ -120,6 +124,16 @@ public MLModel train(Input input) {
return trainable.train(mlInput);
}

public Map<String, String> getConnectorCredential(Connector connector) {
connector.decrypt(PREDICT.name(), (credential) -> encryptor.decrypt(credential));
Map<String, String> decryptedCredential = connector.getDecryptedCredential();
String region = connector.getParameters().get(REGION_FIELD);
if (region != null) {
decryptedCredential.putIfAbsent(REGION_FIELD, region);
}
return decryptedCredential;
}

public Predictable deploy(MLModel mlModel, Map<String, Object> params) {
Predictable predictable = MLEngineClassLoader.initInstance(mlModel.getAlgorithm(), null, MLAlgoParams.class);
predictable.initModel(mlModel, params, encryptor);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
import java.io.IOException;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.Collections;
import java.util.Map;
import java.util.UUID;

import org.junit.Assert;
Expand All @@ -24,11 +26,16 @@
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.mockito.MockedStatic;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.connector.HttpConnector;
import org.opensearch.ml.common.dataframe.ColumnMeta;
import org.opensearch.ml.common.dataframe.DataFrame;
import org.opensearch.ml.common.dataframe.DefaultDataFrame;
Expand All @@ -47,6 +54,7 @@
import org.opensearch.ml.engine.algorithms.regression.LinearRegression;
import org.opensearch.ml.engine.encryptor.Encryptor;
import org.opensearch.ml.engine.encryptor.EncryptorImpl;
import org.opensearch.search.SearchModule;

// TODO: refactor MLEngineClassLoader's static functions to avoid mockStatic
public class MLEngineTest extends MLStaticMockBase {
Expand Down Expand Up @@ -408,4 +416,32 @@ public void testEncryptMethod() {
assertNotEquals(testString, encryptedString);
}

@Test
public void testGetConnectorCredential() throws IOException {
String encryptedValue = mlEngine.encrypt("test_key_value");
String test_connector_string = "{\"name\":\"test_connector_name\",\"version\":\"1\","
+ "\"description\":\"this is a test connector\",\"protocol\":\"http\","
+ "\"parameters\":{\"region\":\"test region\"},\"credential\":{\"key\":\""
+ encryptedValue
+ "\"},"
+ "\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":\"https://test.com\","
+ "\"headers\":{\"api_key\":\"${credential.key}\"},"
+ "\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\"}],"
+ "\"retry_backoff_millis\":10,\"retry_timeout_seconds\":10,\"max_retry_times\":-1,\"retry_backoff_policy\":\"constant\"}}";

XContentParser parser = XContentType.JSON
.xContent()
.createParser(
new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()),
null,
test_connector_string
);
parser.nextToken();

HttpConnector connector = new HttpConnector("http", parser);
Map<String, String> decryptedCredential = mlEngine.getConnectorCredential(connector);
assertNotNull(decryptedCredential);
assertEquals(decryptedCredential.get("key"), "test_key_value");
assertEquals(decryptedCredential.get("region"), "test region");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,8 @@ public void testFilterFieldMapping_ValidInput_EmptyPrefix() {
fieldMap,
ingestFields,
new HashMap<>(),
new HashMap<>()
new HashMap<>(),
null
);
Map<String, Object> result = s3DataIngestion.filterFieldMapping(mlBatchIngestionInput, 0);

Expand All @@ -190,7 +191,8 @@ public void testFilterFieldMapping_MatchingPrefix() {
fieldMap,
ingestFields,
new HashMap<>(),
new HashMap<>()
new HashMap<>(),
null
);

// Act
Expand Down Expand Up @@ -219,7 +221,8 @@ public void testFilterFieldMappingSoleSource_MatchingPrefix() {
fieldMap,
ingestFields,
new HashMap<>(),
new HashMap<>()
new HashMap<>(),
null
);

// Act
Expand Down Expand Up @@ -292,7 +295,8 @@ public void testBatchIngestSuccess_SoleSource() {
fieldMap,
ingestFields,
new HashMap<>(),
new HashMap<>()
new HashMap<>(),
null
);
ActionListener<BulkResponse> bulkResponseListener = mock(ActionListener.class);
s3DataIngestion.batchIngest(sourceLines, mlBatchIngestionInput, bulkResponseListener, 0, true);
Expand All @@ -318,7 +322,8 @@ public void testBatchIngestSuccess_returnForNullJasonMap() {
fieldMap,
ingestFields,
new HashMap<>(),
new HashMap<>()
new HashMap<>(),
null
);
ActionListener<BulkResponse> bulkResponseListener = mock(ActionListener.class);
s3DataIngestion.batchIngest(sourceLines, mlBatchIngestionInput, bulkResponseListener, 0, false);
Expand Down
Loading

0 comments on commit 6cec85c

Please sign in to comment.