diff --git a/crates/key-value-aws/src/store.rs b/crates/key-value-aws/src/store.rs index e5d155317..4171eb0cf 100644 --- a/crates/key-value-aws/src/store.rs +++ b/crates/key-value-aws/src/store.rs @@ -8,7 +8,7 @@ use aws_config::{BehaviorVersion, Region, SdkConfig}; use aws_credential_types::Credentials; use aws_sdk_dynamodb::{ config::{ProvideCredentials, SharedCredentialsProvider}, - operation::{batch_get_item::BatchGetItemOutput, update_item::UpdateItemOutput}, + operation::{batch_get_item::BatchGetItemOutput, get_item::GetItemOutput}, primitives::Blob, types::{AttributeValue, DeleteRequest, KeysAndAttributes, PutRequest, WriteRequest}, Client, @@ -17,8 +17,8 @@ use spin_core::async_trait; use spin_factor_key_value::{log_error, Cas, Error, Store, StoreManager, SwapError}; pub struct KeyValueAwsDynamo { - table: String, - region: String, + table: Arc, + region: Arc, client: async_once_cell::Lazy< Client, std::pin::Pin + Send>>, @@ -96,8 +96,8 @@ impl KeyValueAwsDynamo { }); Ok(Self { - table, - region, + table: Arc::new(table), + region: Arc::new(region), client: async_once_cell::Lazy::from_future(client_fut), }) } @@ -128,18 +128,23 @@ impl StoreManager for KeyValueAwsDynamo { struct AwsDynamoStore { _name: String, client: Client, - table: String, + table: Arc, } struct CompareAndSwap { key: String, client: Client, + table: Arc, bucket_rep: u32, etag: Mutex>, } +/// Primary key in DynamoDB items used for querying items const PK: &str = "PK"; +/// Value key in DynamoDB items storing item value as binary const VAL: &str = "val"; +/// Version key in DynamoDB items used for optimistic locking +const VER: &str = "ver"; #[async_trait] impl Store for AwsDynamoStore { @@ -151,7 +156,7 @@ impl Store for AwsDynamoStore { async fn set(&self, key: &str, value: &[u8]) -> Result<(), Error> { self.client .put_item() - .table_name(&self.table) + .table_name(self.table.as_str()) .item(PK, AttributeValue::S(key.to_string())) .item(VAL, AttributeValue::B(Blob::new(value))) .send() @@ -164,7 +169,7 @@ impl Store for AwsDynamoStore { if self.exists(key).await? { self.client .delete_item() - .table_name(&self.table) + .table_name(self.table.as_str()) .key(PK, AttributeValue::S(key.to_string())) .send() .await @@ -192,11 +197,11 @@ impl Store for AwsDynamoStore { )])) } let mut request_items = Some(HashMap::from_iter([( - self.table.clone(), + self.table.to_string(), keys_and_attributes_builder.build().map_err(log_error)?, )])); - loop { + while request_items.is_some() { let BatchGetItemOutput { responses: Some(mut responses), unprocessed_keys, @@ -212,7 +217,7 @@ impl Store for AwsDynamoStore { return Err(Error::Other("No results".into())); }; - if let Some(items) = responses.remove(&self.table) { + if let Some(items) = responses.remove(self.table.as_str()) { for mut item in items { let Some(AttributeValue::S(pk)) = item.remove(PK) else { return Err(Error::Other( @@ -229,12 +234,10 @@ impl Store for AwsDynamoStore { } } - match unprocessed_keys { - None => return Ok(results), - // TODO: break out if we have retried 10+ times? - remaining_keys => request_items = remaining_keys, - } + request_items = unprocessed_keys; } + + Ok(results) } async fn set_many(&self, key_values: Vec<(String, Vec)>) -> Result<(), Error> { @@ -253,9 +256,9 @@ impl Store for AwsDynamoStore { ) } - let mut request_items = Some(HashMap::from_iter([(self.table.clone(), data)])); + let mut request_items = Some(HashMap::from_iter([(self.table.to_string(), data)])); - loop { + while request_items.is_some() { let results = self .client .batch_write_item() @@ -264,12 +267,10 @@ impl Store for AwsDynamoStore { .await .map_err(log_error)?; - match results.unprocessed_items { - None => return Ok(()), - // TODO: break out if we have retried 10+ times? - remaining_items => request_items = remaining_items, - } + request_items = results.unprocessed_items; } + + Ok(()) } async fn delete_many(&self, keys: Vec) -> Result<(), Error> { @@ -287,30 +288,28 @@ impl Store for AwsDynamoStore { ) } - let mut input = Some(HashMap::from_iter([(self.table.clone(), data)])); + let mut request_items = Some(HashMap::from_iter([(self.table.to_string(), data)])); - loop { + while request_items.is_some() { let results = self .client .batch_write_item() - .set_request_items(input) + .set_request_items(request_items) .send() .await .map_err(log_error)?; - match results.unprocessed_items { - None => return Ok(()), - // TODO: break out if we have retried 10+ times? - remaining_items => input = remaining_items, - } + request_items = results.unprocessed_items; } + + Ok(()) } async fn increment(&self, key: String, delta: i64) -> Result { let result = self .client .update_item() - .table_name(&self.table) + .table_name(self.table.as_str()) .key(PK, AttributeValue::S(key)) .update_expression("ADD #val :delta") .expression_attribute_names("#val", VAL) @@ -337,6 +336,7 @@ impl Store for AwsDynamoStore { Ok(Arc::new(CompareAndSwap { key: key.to_string(), client: self.client.clone(), + table: self.table.clone(), etag: Mutex::new(None), bucket_rep, })) @@ -346,13 +346,77 @@ impl Store for AwsDynamoStore { #[async_trait] impl Cas for CompareAndSwap { async fn current(&self) -> Result>, Error> { - todo!(); + let GetItemOutput { + item: Some(mut current_item), + .. + } = self + .client + .get_item() + .table_name(self.table.as_str()) + .key( + PK, + aws_sdk_dynamodb::types::AttributeValue::S(self.key.clone()), + ) + .send() + .await + .map_err(log_error)? + else { + return Ok(None); + }; + + if let Some(AttributeValue::B(val)) = current_item.remove(VAL) { + let version = if let Some(AttributeValue::N(ver)) = current_item.remove(VER) { + Some(ver) + } else { + Some(String::from("0")) + }; + self.etag.lock().unwrap().clone_from(&version); + Ok(Some(val.into_inner())) + } else { + Ok(None) + } } /// `swap` updates the value for the key using the etag saved in the `current` function for /// optimistic concurrency. async fn swap(&self, value: Vec) -> Result<(), SwapError> { - todo!(); + let mut update_item = self + .client + .update_item() + .table_name(self.table.as_str()) + .key(PK, AttributeValue::S(self.key.clone())) + .update_expression("SET #val=:val, ADD #ver :increment") + .expression_attribute_names("#val", VAL) + .expression_attribute_names("#ver", VER) + .expression_attribute_values(":val", AttributeValue::B(Blob::new(value))) + .expression_attribute_values(":increment", AttributeValue::N("1".to_owned())) + .return_values(aws_sdk_dynamodb::types::ReturnValue::None); + + let current_version = self.etag.lock().unwrap().clone(); + match current_version { + // Existing item with no version key, update under condition that version key still does not exist in Dynamo when operation is executed + Some(version) if version == "0" => { + update_item = update_item.condition_expression("attribute_not_exists(#ver)"); + } + // Existing item with version key, update under condition that version in Dynamo matches stored version + Some(version) => { + update_item = update_item + .condition_expression("#ver = :ver") + .expression_attribute_values(":ver", AttributeValue::N(version)); + } + // Assume new item, insert under condition that item does not already exist + None => { + update_item = update_item + .condition_expression("attribute_not_exists(#pk)") + .expression_attribute_names("#pk", PK); + } + } + + update_item + .send() + .await + .map(|_| ()) + .map_err(|e| SwapError::CasFailed(format!("{e:?}"))) } async fn bucket_rep(&self) -> u32 { @@ -369,7 +433,7 @@ impl AwsDynamoStore { let response = self .client .get_item() - .table_name(&self.table) + .table_name(self.table.as_str()) .key( PK, aws_sdk_dynamodb::types::AttributeValue::S(key.to_string()), @@ -397,7 +461,7 @@ impl AwsDynamoStore { let mut scan_builder = self .client .scan() - .table_name(&self.table) + .table_name(self.table.as_str()) .projection_expression(PK); if let Some(keys) = last_evaluated_key {