Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport to 2.17] use connector credential in offline batch ingestion (#2989) (#3095) #3120

Merged
merged 1 commit into from
Oct 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -179,4 +179,6 @@ default void validateConnectorURL(List<String> urlRegexes) {
}

Map<String, String> getDecryptedHeaders();

Map<String, String> getDecryptedCredential();
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,36 +20,37 @@

import lombok.Builder;
import lombok.Getter;
import lombok.Setter;

/**
* ML batch ingestion data: index, field mapping and input and out files.
*/
@Getter
public class MLBatchIngestionInput implements ToXContentObject, Writeable {

public static final String INDEX_NAME_FIELD = "index_name";
public static final String FIELD_MAP_FIELD = "field_map";
public static final String INGEST_FIELDS = "ingest_fields";
public static final String CONNECTOR_CREDENTIAL_FIELD = "credential";
public static final String DATA_SOURCE_FIELD = "data_source";
public static final String CONNECTOR_ID_FIELD = "connector_id";

@Getter
private String indexName;
@Getter
private Map<String, Object> fieldMapping;
@Getter
private String[] ingestFields;
@Getter
private Map<String, Object> dataSources;
@Getter
@Setter
private Map<String, String> credential;
private String connectorId;

@Builder(toBuilder = true)
public MLBatchIngestionInput(
String indexName,
Map<String, Object> fieldMapping,
String[] ingestFields,
Map<String, Object> dataSources,
Map<String, String> credential
Map<String, String> credential,
String connectorId
) {
if (indexName == null) {
throw new IllegalArgumentException(
Expand All @@ -66,6 +67,7 @@ public MLBatchIngestionInput(
this.ingestFields = ingestFields;
this.dataSources = dataSources;
this.credential = credential;
this.connectorId = connectorId;
}

public static MLBatchIngestionInput parse(XContentParser parser) throws IOException {
Expand All @@ -74,6 +76,7 @@ public static MLBatchIngestionInput parse(XContentParser parser) throws IOExcept
String[] ingestFields = null;
Map<String, Object> dataSources = null;
Map<String, String> credential = new HashMap<>();
String connectorId = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
Expand All @@ -93,6 +96,9 @@ public static MLBatchIngestionInput parse(XContentParser parser) throws IOExcept
case CONNECTOR_CREDENTIAL_FIELD:
credential = parser.mapStrings();
break;
case CONNECTOR_ID_FIELD:
connectorId = parser.text();
break;
case DATA_SOURCE_FIELD:
dataSources = parser.map();
break;
Expand All @@ -101,7 +107,7 @@ public static MLBatchIngestionInput parse(XContentParser parser) throws IOExcept
break;
}
}
return new MLBatchIngestionInput(indexName, fieldMapping, ingestFields, dataSources, credential);
return new MLBatchIngestionInput(indexName, fieldMapping, ingestFields, dataSources, credential, connectorId);
}

@Override
Expand All @@ -119,6 +125,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
if (credential != null) {
builder.field(CONNECTOR_CREDENTIAL_FIELD, credential);
}
if (connectorId != null) {
builder.field(CONNECTOR_ID_FIELD, connectorId);
}
if (dataSources != null) {
builder.field(DATA_SOURCE_FIELD, dataSources);
}
Expand Down Expand Up @@ -147,6 +156,7 @@ public void writeTo(StreamOutput output) throws IOException {
} else {
output.writeBoolean(false);
}
output.writeOptionalString(connectorId);
if (dataSources != null) {
output.writeBoolean(true);
output.writeMap(dataSources, StreamOutput::writeString, StreamOutput::writeGenericValue);
Expand All @@ -166,6 +176,7 @@ public MLBatchIngestionInput(StreamInput input) throws IOException {
if (input.readBoolean()) {
credential = input.readMap(s -> s.readString(), s -> s.readString());
}
this.connectorId = input.readOptionalString();
if (input.readBoolean()) {
dataSources = input.readMap(s -> s.readString(), s -> s.readGenericValue());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,17 @@

package org.opensearch.ml.engine;

import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.PREDICT;
import static org.opensearch.ml.common.connector.HttpConnector.REGION_FIELD;

import java.nio.file.Path;
import java.util.Locale;
import java.util.Map;

import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.connector.Connector;
import org.opensearch.ml.common.dataframe.DataFrame;
import org.opensearch.ml.common.dataset.DataFrameInputDataset;
import org.opensearch.ml.common.dataset.MLInputDataset;
Expand Down Expand Up @@ -120,6 +124,16 @@ public MLModel train(Input input) {
return trainable.train(mlInput);
}

public Map<String, String> getConnectorCredential(Connector connector) {
connector.decrypt(PREDICT.name(), (credential) -> encryptor.decrypt(credential));
Map<String, String> decryptedCredential = connector.getDecryptedCredential();
String region = connector.getParameters().get(REGION_FIELD);
if (region != null) {
decryptedCredential.putIfAbsent(REGION_FIELD, region);
}
return decryptedCredential;
}

public Predictable deploy(MLModel mlModel, Map<String, Object> params) {
Predictable predictable = MLEngineClassLoader.initInstance(mlModel.getAlgorithm(), null, MLAlgoParams.class);
predictable.initModel(mlModel, params, encryptor);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
import java.io.IOException;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.Collections;
import java.util.Map;
import java.util.UUID;

import org.junit.Assert;
Expand All @@ -24,11 +26,16 @@
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.mockito.MockedStatic;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.connector.HttpConnector;
import org.opensearch.ml.common.dataframe.ColumnMeta;
import org.opensearch.ml.common.dataframe.DataFrame;
import org.opensearch.ml.common.dataframe.DefaultDataFrame;
Expand All @@ -47,6 +54,7 @@
import org.opensearch.ml.engine.algorithms.regression.LinearRegression;
import org.opensearch.ml.engine.encryptor.Encryptor;
import org.opensearch.ml.engine.encryptor.EncryptorImpl;
import org.opensearch.search.SearchModule;

// TODO: refactor MLEngineClassLoader's static functions to avoid mockStatic
public class MLEngineTest extends MLStaticMockBase {
Expand Down Expand Up @@ -408,4 +416,32 @@ public void testEncryptMethod() {
assertNotEquals(testString, encryptedString);
}

@Test
public void testGetConnectorCredential() throws IOException {
String encryptedValue = mlEngine.encrypt("test_key_value");
String test_connector_string = "{\"name\":\"test_connector_name\",\"version\":\"1\","
+ "\"description\":\"this is a test connector\",\"protocol\":\"http\","
+ "\"parameters\":{\"region\":\"test region\"},\"credential\":{\"key\":\""
+ encryptedValue
+ "\"},"
+ "\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":\"https://test.com\","
+ "\"headers\":{\"api_key\":\"${credential.key}\"},"
+ "\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\"}],"
+ "\"retry_backoff_millis\":10,\"retry_timeout_seconds\":10,\"max_retry_times\":-1,\"retry_backoff_policy\":\"constant\"}}";

XContentParser parser = XContentType.JSON
.xContent()
.createParser(
new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()),
null,
test_connector_string
);
parser.nextToken();

HttpConnector connector = new HttpConnector("http", parser);
Map<String, String> decryptedCredential = mlEngine.getConnectorCredential(connector);
assertNotNull(decryptedCredential);
assertEquals(decryptedCredential.get("key"), "test_key_value");
assertEquals(decryptedCredential.get("region"), "test region");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,8 @@ public void testFilterFieldMapping_ValidInput_EmptyPrefix() {
fieldMap,
ingestFields,
new HashMap<>(),
new HashMap<>()
new HashMap<>(),
null
);
Map<String, Object> result = s3DataIngestion.filterFieldMapping(mlBatchIngestionInput, 0);

Expand All @@ -190,7 +191,8 @@ public void testFilterFieldMapping_MatchingPrefix() {
fieldMap,
ingestFields,
new HashMap<>(),
new HashMap<>()
new HashMap<>(),
null
);

// Act
Expand Down Expand Up @@ -219,7 +221,8 @@ public void testFilterFieldMappingSoleSource_MatchingPrefix() {
fieldMap,
ingestFields,
new HashMap<>(),
new HashMap<>()
new HashMap<>(),
null
);

// Act
Expand Down Expand Up @@ -292,7 +295,8 @@ public void testBatchIngestSuccess_SoleSource() {
fieldMap,
ingestFields,
new HashMap<>(),
new HashMap<>()
new HashMap<>(),
null
);
ActionListener<BulkResponse> bulkResponseListener = mock(ActionListener.class);
s3DataIngestion.batchIngest(sourceLines, mlBatchIngestionInput, bulkResponseListener, 0, true);
Expand All @@ -318,7 +322,8 @@ public void testBatchIngestSuccess_returnForNullJasonMap() {
fieldMap,
ingestFields,
new HashMap<>(),
new HashMap<>()
new HashMap<>(),
null
);
ActionListener<BulkResponse> bulkResponseListener = mock(ActionListener.class);
s3DataIngestion.batchIngest(sourceLines, mlBatchIngestionInput, bulkResponseListener, 0, false);
Expand Down
Loading
Loading