Skip to content

Commit

Permalink
Refactor initializeMasterKey to use common code
Browse files Browse the repository at this point in the history
Signed-off-by: Daniel Widdis <[email protected]>
  • Loading branch information
dbwiddis committed Dec 20, 2024
1 parent 19b0d4c commit 780be3b
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 61 deletions.
73 changes: 60 additions & 13 deletions src/main/java/org/opensearch/flowframework/util/EncryptorUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -294,10 +294,56 @@ public Template redactTemplateSecuredFields(User user, Template template) {
* @param listener the action listener
*/
public void initializeMasterKey(@Nullable String tenantId, ActionListener<Boolean> listener) {
// Index has either been created or it already exists, need to check if master key has been initalized already, if not then
// generate
// This is necessary in case of global context index restoration from snapshot, will need to use the same master key to decrypt
// stored credentials
// Config index has already been created or verified
cacheMasterKeyFromConfigIndex(tenantId).thenApply(v -> {
// Key exists and has been cached successfully
listener.onResponse(true);
return null;
}).exceptionally(throwable -> {
Exception exception = SdkClientUtils.unwrapAndConvertToException(throwable);
// The cacheMasterKey method only completes exceptionally with FFE
if (exception instanceof FlowFrameworkException) {
FlowFrameworkException ffe = (FlowFrameworkException) exception;
if (ffe.status() == RestStatus.NOT_FOUND) {
// Key doesn't exist, need to generate and index a new one
generateAndIndexNewMasterKey(tenantId, listener);
} else {
listener.onFailure(ffe);
}
} else {
// Shouldn't get here
listener.onFailure(exception);
}
return null;
});
}

private void generateAndIndexNewMasterKey(String tenantId, ActionListener<Boolean> listener) {
Config config = new Config(generateMasterKey(), Instant.now());
IndexRequest masterKeyIndexRequest = new IndexRequest(CONFIG_INDEX).id(MASTER_KEY)
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
try (
ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext();
XContentBuilder builder = XContentFactory.jsonBuilder()
) {
masterKeyIndexRequest.source(config.toXContent(builder, ToXContent.EMPTY_PARAMS));
client.index(masterKeyIndexRequest, ActionListener.wrap(indexResponse -> {
context.restore();
// Set generated key to master
logger.info("Config has been initialized successfully");
setMasterKey(tenantId, config.masterKey());
listener.onResponse(true);
}, indexException -> {
logger.error("Failed to index config", indexException);
listener.onFailure(indexException);
}));
} catch (Exception e) {
logger.error("Failed to index new key in config index", e);
listener.onFailure(new FlowFrameworkException("Failed to index new key in config index", RestStatus.INTERNAL_SERVER_ERROR));
}
}

public void initializeMasterKeyOld(@Nullable String tenantId, ActionListener<Boolean> listener) {
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
GetRequest getRequest = new GetRequest(CONFIG_INDEX).id(MASTER_KEY);
client.get(getRequest, ActionListener.wrap(getResponse -> {
Expand Down Expand Up @@ -358,23 +404,22 @@ CompletableFuture<Void> initializeMasterKeyIfAbsent(@Nullable String tenantId) {
if (this.tenantMasterKeys.containsKey(Objects.requireNonNullElse(tenantId, DEFAULT_TENANT_ID))) {
return CompletableFuture.completedFuture(null);
}
// Key not in map, fetch from config index and store in map
return cacheMasterKeyFromConfigIndex(tenantId);
}

private CompletableFuture<Void> cacheMasterKeyFromConfigIndex(String tenantId) {
// Key not in map
if (!clusterService.state().metadata().hasIndex(CONFIG_INDEX)) {
return CompletableFuture.failedFuture(
new FlowFrameworkException("Config Index has not been initialized", RestStatus.INTERNAL_SERVER_ERROR)
);
}
// Fetch from config index and store in map
return cacheMasterKeyFromConfigIndex(tenantId);
}

private CompletableFuture<Void> cacheMasterKeyFromConfigIndex(String tenantId) {
// This method assumes the config index must exist
final CompletableFuture<Void> resultFuture = new CompletableFuture<>();
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
FetchSourceContext fetchSourceContext = new FetchSourceContext(true);
String masterKeyId = MASTER_KEY;
if (tenantId != null) {
masterKeyId = MASTER_KEY + "_" + hashString(tenantId);
}
String masterKeyId = tenantId == null ? MASTER_KEY : MASTER_KEY + "_" + hashString(tenantId);
sdkClient.getDataObjectAsync(
GetDataObjectRequest.builder()
.index(CONFIG_INDEX)
Expand All @@ -390,6 +435,7 @@ private CompletableFuture<Void> cacheMasterKeyFromConfigIndex(String tenantId) {
try {
response = r.parser() == null ? null : GetResponse.fromXContent(r.parser());
if (response.isExists()) {
System.err.println("B: EXISTS");
try (
XContentParser parser = ParseUtils.createXContentParserFromRegistry(
xContentRegistry,
Expand All @@ -402,6 +448,7 @@ private CompletableFuture<Void> cacheMasterKeyFromConfigIndex(String tenantId) {
resultFuture.complete(null);
}
} else {
System.err.println("C: NOT EXISTS");
resultFuture.completeExceptionally(
new FlowFrameworkException("Master key has not been initialized in config index", RestStatus.NOT_FOUND)
);
Expand Down
20 changes: 20 additions & 0 deletions src/test/java/org/opensearch/flowframework/TestHelpers.java
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,13 @@
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static org.opensearch.flowframework.common.CommonValue.MASTER_KEY;
import static org.opensearch.index.seqno.SequenceNumbers.UNASSIGNED_SEQ_NO;
import static org.opensearch.test.OpenSearchTestCase.randomAlphaOfLength;
import static org.apache.hc.core5.http.ContentType.APPLICATION_JSON;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

public class TestHelpers {

Expand Down Expand Up @@ -191,6 +195,22 @@ public static SearchRequest matchAllRequest() {
}

public static GetResponse createGetResponse(ToXContentObject o, String id, String indexName) throws IOException {
if (o == null) {
GetResponse getResponse = mock(GetResponse.class);
when(getResponse.getId()).thenReturn(MASTER_KEY);
when(getResponse.getSource()).thenReturn(null);
when(getResponse.toXContent(any(XContentBuilder.class), any())).thenAnswer(invocation -> {
XContentBuilder builder = invocation.getArgument(0);
builder.startObject()
.field("_index", indexName)
.field("_id", id)
.field("found", false)
// .nullField("_source")
.endObject();
return builder;
});
return getResponse;
}
XContentBuilder content = o.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS);
return new GetResponse(
new GetResult(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
package org.opensearch.flowframework.util;

import org.opensearch.Version;
import org.opensearch.action.LatchedActionListener;
import org.opensearch.action.get.GetRequest;
import org.opensearch.action.get.GetResponse;
import org.opensearch.action.index.IndexRequest;
Expand All @@ -21,14 +22,10 @@
import org.opensearch.common.settings.Settings;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.common.util.concurrent.OpenSearchExecutors;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.flowframework.TestHelpers;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.flowframework.model.Config;
Expand All @@ -50,6 +47,7 @@
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
Expand Down Expand Up @@ -174,50 +172,36 @@ public void testInitializeMasterKeySuccess() throws IOException, InterruptedExce
// Index exists case
// reinitialize with blank master key
this.encryptorUtils = new EncryptorUtils(clusterService, client, sdkClient, xContentRegistry);
BytesReference bytesRef;
try (XContentBuilder builder = XContentFactory.jsonBuilder()) {
Config config = new Config(masterKey, Instant.now());
XContentBuilder source = config.toXContent(builder, ToXContent.EMPTY_PARAMS);
bytesRef = BytesReference.bytes(source);
}
doAnswer(invocation -> {
ActionListener<GetResponse> getRequestActionListener = invocation.getArgument(1);

// Stub get response for success case
GetResponse getResponse = mock(GetResponse.class);
when(getResponse.isExists()).thenReturn(true);
when(getResponse.getSourceAsBytesRef()).thenReturn(bytesRef);
assertNull(encryptorUtils.getMasterKey(null));

getRequestActionListener.onResponse(getResponse);
return null;
}).when(client).get(any(GetRequest.class), any());
GetResponse getMasterKeyResponse = TestHelpers.createGetResponse(new Config(masterKey, Instant.now()), MASTER_KEY, CONFIG_INDEX);
PlainActionFuture<GetResponse> future = PlainActionFuture.newFuture();
future.onResponse(getMasterKeyResponse);
when(client.get(any(GetRequest.class))).thenReturn(future);

ActionListener<Boolean> listener = ActionListener.wrap(b -> {}, e -> {});
encryptorUtils.initializeMasterKey(null, listener);

CountDownLatch latch = new CountDownLatch(1);
LatchedActionListener<Boolean> latchedActionListener = new LatchedActionListener<>(listener, latch);
encryptorUtils.initializeMasterKey(null, latchedActionListener);
latch.await(1, TimeUnit.SECONDS);
assertEquals(masterKey, encryptorUtils.getMasterKey(null));

// Test ifAbsent version
// reinitialize with blank master key
this.encryptorUtils = new EncryptorUtils(clusterService, client, sdkClient, xContentRegistry);
assertNull(encryptorUtils.getMasterKey(null));

GetResponse getMasterKeyResponse = TestHelpers.createGetResponse(new Config(masterKey, Instant.now()), MASTER_KEY, CONFIG_INDEX);
PlainActionFuture<GetResponse> future = PlainActionFuture.newFuture();
future.onResponse(getMasterKeyResponse);
when(client.get(any(GetRequest.class))).thenReturn(future);

CompletableFuture<Void> resultFuture = encryptorUtils.initializeMasterKeyIfAbsent(null);
resultFuture.get(5, TimeUnit.SECONDS);
assertEquals(masterKey, encryptorUtils.getMasterKey(null));

// No index exists case
doAnswer(invocation -> {
ActionListener<GetResponse> getRequestActionListener = invocation.getArgument(1);
GetResponse getResponse = mock(GetResponse.class);
when(getResponse.isExists()).thenReturn(false);
getRequestActionListener.onResponse(getResponse);
return null;
}).when(client).get(any(GetRequest.class), any());
// No key exists case
getMasterKeyResponse = TestHelpers.createGetResponse(null, MASTER_KEY, CONFIG_INDEX);
future = PlainActionFuture.newFuture();
future.onResponse(getMasterKeyResponse);
when(client.get(any(GetRequest.class))).thenReturn(future);

doAnswer(invocation -> {
ActionListener<IndexResponse> indexRequestActionListener = invocation.getArgument(1);
IndexResponse indexResponse = mock(IndexResponse.class);
Expand All @@ -226,8 +210,11 @@ public void testInitializeMasterKeySuccess() throws IOException, InterruptedExce
}).when(client).index(any(IndexRequest.class), any());

listener = ActionListener.wrap(b -> {}, e -> {});
encryptorUtils.initializeMasterKey(null, listener);
latch = new CountDownLatch(1);
latchedActionListener = new LatchedActionListener<>(listener, latch);
encryptorUtils.initializeMasterKey(null, latchedActionListener);
// This will generate a new master key 32 bytes -> base64 encoded
latch.await(1, TimeUnit.SECONDS);
assertNotEquals(masterKey, encryptorUtils.getMasterKey(null));
assertEquals(masterKey.length(), encryptorUtils.getMasterKey(null).length());
}
Expand All @@ -236,19 +223,7 @@ public void testInitializeMasterKeyFailure() throws IOException {
// reinitialize with blank master key
this.encryptorUtils = new EncryptorUtils(clusterService, client, sdkClient, xContentRegistry);

GetResponse getResponse = mock(GetResponse.class);
when(getResponse.getId()).thenReturn(MASTER_KEY);
when(getResponse.getSource()).thenReturn(null);
when(getResponse.toXContent(any(XContentBuilder.class), any())).thenAnswer(invocation -> {
XContentBuilder builder = invocation.getArgument(0);
builder.startObject()
.field("_index", CONFIG_INDEX)
.field("_id", MASTER_KEY)
.field("found", false)
// .nullField("_source")
.endObject();
return builder;
});
GetResponse getResponse = TestHelpers.createGetResponse(null, MASTER_KEY, CONFIG_INDEX);
PlainActionFuture<GetResponse> future = PlainActionFuture.newFuture();
future.onResponse(getResponse);
when(client.get(any(GetRequest.class))).thenReturn(future);
Expand Down

0 comments on commit 780be3b

Please sign in to comment.