Skip to content

Commit

Permalink
CAS implementation
Browse files Browse the repository at this point in the history
Signed-off-by: Darwin Boersma <[email protected]>
  • Loading branch information
ogghead committed Nov 7, 2024
1 parent 15276d8 commit d7b5ec1
Showing 1 changed file with 100 additions and 36 deletions.
136 changes: 100 additions & 36 deletions crates/key-value-aws/src/store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use aws_config::{BehaviorVersion, Region, SdkConfig};
use aws_credential_types::Credentials;
use aws_sdk_dynamodb::{
config::{ProvideCredentials, SharedCredentialsProvider},
operation::{batch_get_item::BatchGetItemOutput, update_item::UpdateItemOutput},
operation::{batch_get_item::BatchGetItemOutput, get_item::GetItemOutput},
primitives::Blob,
types::{AttributeValue, DeleteRequest, KeysAndAttributes, PutRequest, WriteRequest},
Client,
Expand All @@ -17,8 +17,8 @@ use spin_core::async_trait;
use spin_factor_key_value::{log_error, Cas, Error, Store, StoreManager, SwapError};

pub struct KeyValueAwsDynamo {
table: String,
region: String,
table: Arc<String>,
region: Arc<String>,
client: async_once_cell::Lazy<
Client,
std::pin::Pin<Box<dyn std::future::Future<Output = Client> + Send>>,
Expand Down Expand Up @@ -96,8 +96,8 @@ impl KeyValueAwsDynamo {
});

Ok(Self {
table,
region,
table: Arc::new(table),
region: Arc::new(region),
client: async_once_cell::Lazy::from_future(client_fut),
})
}
Expand Down Expand Up @@ -128,18 +128,23 @@ impl StoreManager for KeyValueAwsDynamo {
struct AwsDynamoStore {
_name: String,
client: Client,
table: String,
table: Arc<String>,
}

struct CompareAndSwap {
key: String,
client: Client,
table: Arc<String>,
bucket_rep: u32,
etag: Mutex<Option<String>>,
}

/// Primary key in DynamoDB items used for querying items
const PK: &str = "PK";
/// Value key in DynamoDB items storing item value as binary
const VAL: &str = "val";
/// Version key in DynamoDB items used for optimistic locking
const VER: &str = "ver";

#[async_trait]
impl Store for AwsDynamoStore {
Expand All @@ -151,7 +156,7 @@ impl Store for AwsDynamoStore {
async fn set(&self, key: &str, value: &[u8]) -> Result<(), Error> {
self.client
.put_item()
.table_name(&self.table)
.table_name(self.table.as_str())
.item(PK, AttributeValue::S(key.to_string()))
.item(VAL, AttributeValue::B(Blob::new(value)))
.send()
Expand All @@ -164,7 +169,7 @@ impl Store for AwsDynamoStore {
if self.exists(key).await? {
self.client
.delete_item()
.table_name(&self.table)
.table_name(self.table.as_str())
.key(PK, AttributeValue::S(key.to_string()))
.send()
.await
Expand Down Expand Up @@ -192,11 +197,11 @@ impl Store for AwsDynamoStore {
)]))
}
let mut request_items = Some(HashMap::from_iter([(
self.table.clone(),
self.table.to_string(),
keys_and_attributes_builder.build().map_err(log_error)?,
)]));

loop {
while request_items.is_some() {
let BatchGetItemOutput {
responses: Some(mut responses),
unprocessed_keys,
Expand All @@ -212,7 +217,7 @@ impl Store for AwsDynamoStore {
return Err(Error::Other("No results".into()));
};

if let Some(items) = responses.remove(&self.table) {
if let Some(items) = responses.remove(self.table.as_str()) {
for mut item in items {
let Some(AttributeValue::S(pk)) = item.remove(PK) else {
return Err(Error::Other(
Expand All @@ -229,12 +234,10 @@ impl Store for AwsDynamoStore {
}
}

match unprocessed_keys {
None => return Ok(results),
// TODO: break out if we have retried 10+ times?
remaining_keys => request_items = remaining_keys,
}
request_items = unprocessed_keys;
}

Ok(results)
}

async fn set_many(&self, key_values: Vec<(String, Vec<u8>)>) -> Result<(), Error> {
Expand All @@ -253,9 +256,9 @@ impl Store for AwsDynamoStore {
)
}

let mut request_items = Some(HashMap::from_iter([(self.table.clone(), data)]));
let mut request_items = Some(HashMap::from_iter([(self.table.to_string(), data)]));

loop {
while request_items.is_some() {
let results = self
.client
.batch_write_item()
Expand All @@ -264,12 +267,10 @@ impl Store for AwsDynamoStore {
.await
.map_err(log_error)?;

match results.unprocessed_items {
None => return Ok(()),
// TODO: break out if we have retried 10+ times?
remaining_items => request_items = remaining_items,
}
request_items = results.unprocessed_items;
}

Ok(())
}

async fn delete_many(&self, keys: Vec<String>) -> Result<(), Error> {
Expand All @@ -287,30 +288,28 @@ impl Store for AwsDynamoStore {
)
}

let mut input = Some(HashMap::from_iter([(self.table.clone(), data)]));
let mut request_items = Some(HashMap::from_iter([(self.table.to_string(), data)]));

loop {
while request_items.is_some() {
let results = self
.client
.batch_write_item()
.set_request_items(input)
.set_request_items(request_items)
.send()
.await
.map_err(log_error)?;

match results.unprocessed_items {
None => return Ok(()),
// TODO: break out if we have retried 10+ times?
remaining_items => input = remaining_items,
}
request_items = results.unprocessed_items;
}

Ok(())
}

async fn increment(&self, key: String, delta: i64) -> Result<i64, Error> {
let result = self
.client
.update_item()
.table_name(&self.table)
.table_name(self.table.as_str())
.key(PK, AttributeValue::S(key))
.update_expression("ADD #val :delta")
.expression_attribute_names("#val", VAL)
Expand All @@ -337,6 +336,7 @@ impl Store for AwsDynamoStore {
Ok(Arc::new(CompareAndSwap {
key: key.to_string(),
client: self.client.clone(),
table: self.table.clone(),
etag: Mutex::new(None),
bucket_rep,
}))
Expand All @@ -346,13 +346,77 @@ impl Store for AwsDynamoStore {
#[async_trait]
impl Cas for CompareAndSwap {
async fn current(&self) -> Result<Option<Vec<u8>>, Error> {
todo!();
let GetItemOutput {
item: Some(mut current_item),
..
} = self
.client
.get_item()
.table_name(self.table.as_str())
.key(
PK,
aws_sdk_dynamodb::types::AttributeValue::S(self.key.clone()),
)
.send()
.await
.map_err(log_error)?
else {
return Ok(None);
};

if let Some(AttributeValue::B(val)) = current_item.remove(VAL) {
let version = if let Some(AttributeValue::N(ver)) = current_item.remove(VER) {
Some(ver)
} else {
Some(String::from("0"))
};
self.etag.lock().unwrap().clone_from(&version);
Ok(Some(val.into_inner()))
} else {
Ok(None)
}
}

/// `swap` updates the value for the key using the etag saved in the `current` function for
/// optimistic concurrency.
async fn swap(&self, value: Vec<u8>) -> Result<(), SwapError> {
todo!();
let mut update_item = self
.client
.update_item()
.table_name(self.table.as_str())
.key(PK, AttributeValue::S(self.key.clone()))
.update_expression("SET #val=:val, ADD #ver :increment")
.expression_attribute_names("#val", VAL)
.expression_attribute_names("#ver", VER)
.expression_attribute_values(":val", AttributeValue::B(Blob::new(value)))
.expression_attribute_values(":increment", AttributeValue::N("1".to_owned()))
.return_values(aws_sdk_dynamodb::types::ReturnValue::None);

let current_version = self.etag.lock().unwrap().clone();
match current_version {
// Existing item with no version key, update under condition that version key still does not exist in Dynamo when operation is executed
Some(version) if version == "0" => {
update_item = update_item.condition_expression("attribute_not_exists(#ver)");
}
// Existing item with version key, update under condition that version in Dynamo matches stored version
Some(version) => {
update_item = update_item
.condition_expression("#ver = :ver")
.expression_attribute_values(":ver", AttributeValue::N(version));
}
// Assume new item, insert under condition that item does not already exist
None => {
update_item = update_item
.condition_expression("attribute_not_exists(#pk)")
.expression_attribute_names("#pk", PK);
}
}

update_item
.send()
.await
.map(|_| ())
.map_err(|e| SwapError::CasFailed(format!("{e:?}")))
}

async fn bucket_rep(&self) -> u32 {
Expand All @@ -369,7 +433,7 @@ impl AwsDynamoStore {
let response = self
.client
.get_item()
.table_name(&self.table)
.table_name(self.table.as_str())
.key(
PK,
aws_sdk_dynamodb::types::AttributeValue::S(key.to_string()),
Expand Down Expand Up @@ -397,7 +461,7 @@ impl AwsDynamoStore {
let mut scan_builder = self
.client
.scan()
.table_name(&self.table)
.table_name(self.table.as_str())
.projection_expression(PK);

if let Some(keys) = last_evaluated_key {
Expand Down

0 comments on commit d7b5ec1

Please sign in to comment.