Skip to content

Commit

Permalink
[Feature/multi_tenancy] Add ifSeqNo and ifPrimaryTerm for update conc…
Browse files Browse the repository at this point in the history
…urrency (#2605)

* Add ifSeqNo and ifPrimaryTerm for update concurrency

Signed-off-by: Daniel Widdis <[email protected]>

* Add test classes

Signed-off-by: Daniel Widdis <[email protected]>

---------

Signed-off-by: Daniel Widdis <[email protected]>
  • Loading branch information
dbwiddis authored Jul 9, 2024
1 parent 6da6ce6 commit a421f49
Show file tree
Hide file tree
Showing 12 changed files with 282 additions and 62 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;
import static org.opensearch.index.seqno.SequenceNumbers.UNASSIGNED_SEQ_NO;

import java.io.IOException;
import java.util.Map;
Expand All @@ -19,6 +20,8 @@ public class UpdateDataObjectRequest {
private final String index;
private final String id;
private final String tenantId;
private final Long ifSeqNo;
private final Long ifPrimaryTerm;
private final ToXContentObject dataObject;

/**
Expand All @@ -28,12 +31,16 @@ public class UpdateDataObjectRequest {
* @param index the index location to update the object
* @param id the document id
* @param tenantId the tenant id
* @param ifSeqNo the sequence number to match or null if not required
* @param ifPrimaryTerm the primary term to match or null if not required
* @param dataObject the data object
*/
public UpdateDataObjectRequest(String index, String id, String tenantId, ToXContentObject dataObject) {
public UpdateDataObjectRequest(String index, String id, String tenantId, Long ifSeqNo, Long ifPrimaryTerm, ToXContentObject dataObject) {
this.index = index;
this.id = id;
this.tenantId = tenantId;
this.ifSeqNo = ifSeqNo;
this.ifPrimaryTerm = ifPrimaryTerm;
this.dataObject = dataObject;
}

Expand Down Expand Up @@ -61,6 +68,22 @@ public String tenantId() {
return this.tenantId;
}

/**
* Returns the sequence number to match, or null if no match required
* @return the ifSeqNo
*/
public Long ifSeqNo() {
return ifSeqNo;
}

/**
* Returns the primary term to match, or null if no match required
* @return the ifPrimaryTerm
*/
public Long ifPrimaryTerm() {
return ifPrimaryTerm;
}

/**
* Returns the data object
* @return the data object
Expand All @@ -84,6 +107,8 @@ public static class Builder {
private String index = null;
private String id = null;
private String tenantId = null;
private Long ifSeqNo = null;
private Long ifPrimaryTerm = null;
private ToXContentObject dataObject = null;

/**
Expand Down Expand Up @@ -120,7 +145,35 @@ public Builder tenantId(String tenantId) {
this.tenantId = tenantId;
return this;
}


/**
* Only perform this update request if the document's modification was assigned the given
* sequence number. Must be used in combination with {@link #ifPrimaryTerm(long)}
* <p>
* Sequence number may be represented by a different document versioning key on non-OpenSearch data stores.
*/
public Builder ifSeqNo(long seqNo) {
if (seqNo < 0 && seqNo != UNASSIGNED_SEQ_NO) {
throw new IllegalArgumentException("sequence numbers must be non negative. got [" + seqNo + "].");
}
this.ifSeqNo = seqNo;
return this;
}

/**
* Only performs this update request if the document's last modification was assigned the given
* primary term. Must be used in combination with {@link #ifSeqNo(long)}
* <p>
* Primary term may not be relevant on non-OpenSearch data stores.
*/
public Builder ifPrimaryTerm(long term) {
if (term < 0) {
throw new IllegalArgumentException("primary term must be non negative. got [" + term + "]");
}
this.ifPrimaryTerm = term;
return this;
}

/**
* Add a data object to this builder
* @param dataObject the data object
Expand Down Expand Up @@ -150,7 +203,10 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
* @return A {@link UpdateDataObjectRequest}
*/
public UpdateDataObjectRequest build() {
return new UpdateDataObjectRequest(this.index, this.id, this.tenantId, this.dataObject);
if ((ifSeqNo == null) != (ifPrimaryTerm == null)) {
throw new IllegalArgumentException("Either ifSeqNo and ifPrimaryTerm must both be null or both must be non-null.");
}
return new UpdateDataObjectRequest(this.index, this.id, this.tenantId, this.ifSeqNo, this.ifPrimaryTerm, this.dataObject);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

import org.junit.Before;
import org.junit.Test;
import org.opensearch.OpenSearchException;
import org.opensearch.OpenSearchStatusException;
import org.opensearch.action.support.PlainActionFuture;
import org.opensearch.core.rest.RestStatus;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,22 @@
import org.opensearch.common.xcontent.json.JsonXContent;
import org.opensearch.core.common.Strings;
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.sdk.UpdateDataObjectRequest.Builder;

import java.util.Map;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertThrows;
import static org.mockito.Mockito.mock;

public class UpdateDataObjectRequestTests {

private String testIndex;
private String testId;
private String testTenantId;
private Long testSeqNo;
private Long testPrimaryTerm;
private ToXContentObject testDataObject;
private Map<String, Object> testDataObjectMap;

Expand All @@ -34,6 +39,8 @@ public void setUp() {
testIndex = "test-index";
testId = "test-id";
testTenantId = "test-tenant-id";
testSeqNo = 42L;
testPrimaryTerm = 6L;
testDataObject = mock(ToXContentObject.class);
testDataObjectMap = Map.of("foo", "bar");
}
Expand All @@ -46,6 +53,8 @@ public void testUpdateDataObjectRequest() {
assertEquals(testId, request.id());
assertEquals(testTenantId, request.tenantId());
assertEquals(testDataObject, request.dataObject());
assertNull(request.ifSeqNo());
assertNull(request.ifPrimaryTerm());
}

@Test
Expand All @@ -57,4 +66,26 @@ public void testUpdateDataObjectMapRequest() {
assertEquals(testTenantId, request.tenantId());
assertEquals(testDataObjectMap, XContentHelper.convertToMap(JsonXContent.jsonXContent, Strings.toString(XContentType.JSON, request.dataObject()), false));
}

@Test
public void testUpdateDataObjectRequestConcurrency() {
UpdateDataObjectRequest request = UpdateDataObjectRequest.builder().index(testIndex).id(testId).tenantId(testTenantId).dataObject(testDataObject)
.ifSeqNo(testSeqNo).ifPrimaryTerm(testPrimaryTerm).build();

assertEquals(testIndex, request.index());
assertEquals(testId, request.id());
assertEquals(testTenantId, request.tenantId());
assertEquals(testDataObject, request.dataObject());
assertEquals(testSeqNo, request.ifSeqNo());
assertEquals(testPrimaryTerm, request.ifPrimaryTerm());

final Builder badSeqNoBuilder = UpdateDataObjectRequest.builder();
assertThrows(IllegalArgumentException.class, () -> badSeqNoBuilder.ifSeqNo(-99));
final Builder badPrimaryTermBuilder = UpdateDataObjectRequest.builder();
assertThrows(IllegalArgumentException.class, () -> badPrimaryTermBuilder.ifPrimaryTerm(-99));
final Builder onlySeqNoBuilder = UpdateDataObjectRequest.builder().ifSeqNo(testSeqNo);
assertThrows(IllegalArgumentException.class, () -> onlySeqNoBuilder.build());
final Builder onlyPrimaryTermBuilder = UpdateDataObjectRequest.builder().ifPrimaryTerm(testPrimaryTerm);
assertThrows(IllegalArgumentException.class, () -> onlyPrimaryTermBuilder.build());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,6 @@

import org.junit.Before;
import org.junit.Test;
import org.opensearch.action.support.replication.ReplicationResponse.ShardInfo;
import org.opensearch.core.common.Strings;
import org.opensearch.core.index.shard.ShardId;
import org.opensearch.core.xcontent.XContentParser;

import static org.junit.Assert.assertEquals;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -615,16 +615,6 @@ private UpdateDataObjectRequest createUpdateModelGroupRequest(
) {
modelGroupSourceMap.put(MLModelGroup.LATEST_VERSION_FIELD, updatedVersion);
modelGroupSourceMap.put(MLModelGroup.LAST_UPDATED_TIME_FIELD, Instant.now().toEpochMilli());
/* Old code here. TODO investigate if we need to add seqNo and primaryTerm to data object request
UpdateRequest updateModelGroupRequest = new UpdateRequest();
updateModelGroupRequest
.index(ML_MODEL_GROUP_INDEX)
.id(modelGroupId)
.setIfSeqNo(seqNo)
.setIfPrimaryTerm(primaryTerm)
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
.doc(modelGroupSourceMap);
*/
ToXContentObject dataObject = new ToXContentObject() {
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
Expand All @@ -635,7 +625,14 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
return builder.endObject();
}
};
return UpdateDataObjectRequest.builder().index(ML_MODEL_GROUP_INDEX).id(modelGroupId).dataObject(dataObject).build();
return UpdateDataObjectRequest
.builder()
.index(ML_MODEL_GROUP_INDEX)
.id(modelGroupId)
.ifSeqNo(seqNo)
.ifPrimaryTerm(primaryTerm)
.dataObject(dataObject)
.build();
}

private Boolean isModelDeployed(MLModelState mlModelState) {
Expand Down
15 changes: 2 additions & 13 deletions plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -377,25 +377,14 @@ public void registerMLRemoteModel(
if (getModelGroupResponse.isExists()) {
Map<String, Object> modelGroupSourceMap = getModelGroupResponse.getSourceAsMap();
int updatedVersion = incrementLatestVersion(modelGroupSourceMap);
/* TODO UpdateDataObjectRequest needs to track response seqNo + primary term
UpdateRequest updateModelGroupRequest = createUpdateModelGroupRequest(
modelGroupSourceMap,
modelGroupId,
getModelGroupResponse.getSeqNo(),
getModelGroupResponse.getPrimaryTerm(),
updatedVersion
);
*/
modelGroupSourceMap.put(MLModelGroup.LATEST_VERSION_FIELD, updatedVersion);
modelGroupSourceMap.put(MLModelGroup.LAST_UPDATED_TIME_FIELD, Instant.now().toEpochMilli());
UpdateDataObjectRequest updateDataObjectRequest = UpdateDataObjectRequest
.builder()
.index(ML_MODEL_GROUP_INDEX)
.id(modelGroupId)
// TODO need to track these for concurrency
// .setIfSeqNo(seqNo)
// .setIfPrimaryTerm(primaryTerm)
// .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
.ifSeqNo(getModelGroupResponse.getSeqNo())
.ifPrimaryTerm(getModelGroupResponse.getPrimaryTerm())
.dataObject(modelGroupSourceMap)
.build();
sdkClient.updateDataObjectAsync(updateDataObjectRequest).whenComplete((ur, ut) -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
import lombok.extern.log4j.Log4j2;
import software.amazon.awssdk.services.dynamodb.DynamoDbClient;
import software.amazon.awssdk.services.dynamodb.model.AttributeValue;
import software.amazon.awssdk.services.dynamodb.model.ConditionalCheckFailedException;
import software.amazon.awssdk.services.dynamodb.model.DeleteItemRequest;
import software.amazon.awssdk.services.dynamodb.model.GetItemRequest;
import software.amazon.awssdk.services.dynamodb.model.GetItemResponse;
Expand All @@ -71,8 +72,9 @@ public class DDBOpenSearchClient implements SdkClient {
private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();
private static final String DEFAULT_TENANT = "DEFAULT_TENANT";

private static final String HASH_KEY = "tenant_id";
private static final String RANGE_KEY = "id";
private static final String HASH_KEY = "_tenant_id";
private static final String RANGE_KEY = "_id";
private static final String SEQ_NO_KEY = "_seq_no";

private DynamoDbClient dynamoDbClient;
private RemoteClusterIndicesClient remoteClusterIndicesClient;
Expand Down Expand Up @@ -109,7 +111,11 @@ public CompletionStage<PutDataObjectResponse> putDataObjectAsync(PutDataObjectRe
item.put(RANGE_KEY, AttributeValue.builder().s(id).build());
final PutItemRequest putItemRequest = PutItemRequest.builder().tableName(tableName).item(item).build();

// TODO need to initialize/return SEQ_NO here
// If document doesn't exist, return 0
// If document exists, overwrite and increment and return SEQ_NO
dynamoDbClient.putItem(putItemRequest);
// TODO need to pass seqNo to simulated response
String simulatedIndexResponse = simulateOpenSearchResponse(
request.index(),
request.id(),
Expand Down Expand Up @@ -139,6 +145,7 @@ public CompletionStage<GetDataObjectResponse> getDataObjectAsync(GetDataObjectRe
.ofEntries(
Map.entry(HASH_KEY, AttributeValue.builder().s(tenantId).build()),
Map.entry(RANGE_KEY, AttributeValue.builder().s(request.id()).build())
// TODO need to fetch SEQ_NO_KEY
)
)
.build();
Expand Down Expand Up @@ -187,15 +194,33 @@ public CompletionStage<UpdateDataObjectResponse> updateDataObjectAsync(UpdateDat
Map<String, AttributeValue> updateItem = JsonTransformer.convertJsonObjectToDDBAttributeMap(jsonNode);
updateItem.put(HASH_KEY, AttributeValue.builder().s(tenantId).build());
updateItem.put(RANGE_KEY, AttributeValue.builder().s(request.id()).build());
UpdateItemRequest updateItemRequest = UpdateItemRequest
UpdateItemRequest.Builder updateItemRequestBuilder = UpdateItemRequest
.builder()
.tableName(getTableName(request.index()))
.key(updateItem)
.build();
.key(updateItem);
if (request.ifSeqNo() != null) {
// Get current document version and put in attribute map. Ignore primary term on DDB.
int currentSeqNo = jsonNode.has(SEQ_NO_KEY) ? jsonNode.get(SEQ_NO_KEY).asInt() : 0;
updateItemRequestBuilder
.conditionExpression("#seqNo = :currentSeqNo")
.expressionAttributeNames(Map.of("#seqNo", SEQ_NO_KEY))
.expressionAttributeValues(
Map.of(":currentSeqNo", AttributeValue.builder().n(Integer.toString(currentSeqNo)).build())
);
}
UpdateItemRequest updateItemRequest = updateItemRequestBuilder.build();
// TODO need to add an incremented seqNo here
dynamoDbClient.updateItem(updateItemRequest);

// TODO need to pass seqNo to simulated response
String simulatedUpdateResponse = simulateOpenSearchResponse(request.index(), request.id(), source, Map.of("found", true));
return UpdateDataObjectResponse.builder().id(request.id()).parser(createParser(simulatedUpdateResponse)).build();
} catch (ConditionalCheckFailedException ccfe) {
log.error("Document version conflict updating {} in {}: {}", request.id(), request.index(), ccfe.getMessage(), ccfe);
// Rethrow
throw new OpenSearchStatusException(
"Document version conflict updating " + request.id() + " in index " + request.index(),
RestStatus.CONFLICT
);
} catch (IOException e) {
log.error("Error updating {} in {}: {}", request.id(), request.index(), e.getMessage(), e);
// Rethrow unchecked exception on update IOException
Expand Down Expand Up @@ -227,7 +252,12 @@ public CompletionStage<DeleteDataObjectResponse> deleteDataObjectAsync(DeleteDat
.build();
return CompletableFuture.supplyAsync(() -> AccessController.doPrivileged((PrivilegedAction<DeleteDataObjectResponse>) () -> {
try {
// TODO need to return SEQ_NO here
// If document doesn't exist, increment and return highest seq no ever seen, but we would have to track seqNo here
// If document never existed, return -2 (unassigned) for seq no (probably what we have to do here)
// If document exists, increment and return SEQ_NO
dynamoDbClient.deleteItem(deleteItemRequest);
// TODO need to pass seqNo to simulated response
String simulatedDeleteResponse = simulateOpenSearchResponse(
request.index(),
request.id(),
Expand Down
Loading

0 comments on commit a421f49

Please sign in to comment.