diff --git a/radix-clis/src/rtmd/mod.rs b/radix-clis/src/rtmd/mod.rs index c9cbb928c34..05553ff999d 100644 --- a/radix-clis/src/rtmd/mod.rs +++ b/radix-clis/src/rtmd/mod.rs @@ -61,54 +61,55 @@ pub fn run() -> Result<(), String> { None => NetworkDefinition::simulator(), }; - let (manifest_instructions, blobs) = - match manifest_decode::(&content).map_err(Error::DecodeError) { - Ok(manifest) => { - let blobs: Vec> = manifest - .blobs - .values() - .into_iter() - .map(|item| item.to_owned()) - .collect(); - (manifest.instructions, blobs) - } - Err(e) => { - // try to decode versioned transaction - match manifest_decode::(&content) { - Ok(manifest) => { - let (manifest_instructions, blobs) = match manifest { - VersionedTransactionPayload::IntentV1 { - instructions, - blobs, - .. - } => (instructions.0, blobs.blobs), - VersionedTransactionPayload::SignedIntentV1 { intent, .. } => { - (intent.instructions.0, intent.blobs.blobs) - } - VersionedTransactionPayload::NotarizedTransactionV1 { - signed_intent, - .. - } => ( - signed_intent.intent.instructions.0, - signed_intent.intent.blobs.blobs, - ), - VersionedTransactionPayload::SystemTransactionV1 { - instructions, - blobs, - .. - } => (instructions.0, blobs.blobs), - }; + let (manifest_instructions, blobs) = match manifest_decode::(&content) + .map_err(Error::DecodeError) + { + Ok(manifest) => { + let blobs: Vec> = manifest + .blobs + .values() + .into_iter() + .map(|item| item.to_owned()) + .collect(); + (manifest.instructions, blobs) + } + Err(e) => { + // try to decode versioned transaction + match manifest_decode::(&content) { + Ok(manifest) => { + let (manifest_instructions, blobs) = match manifest { + VersionedTransactionPayload::IntentV1(IntentV1 { + instructions, + blobs, + .. + }) => (instructions.0, blobs.blobs), + VersionedTransactionPayload::SignedIntentV1(SignedIntentV1 { + intent, + .. + }) => (intent.instructions.0, intent.blobs.blobs), + VersionedTransactionPayload::NotarizedTransactionV1( + NotarizedTransactionV1 { signed_intent, .. }, + ) => ( + signed_intent.intent.instructions.0, + signed_intent.intent.blobs.blobs, + ), + VersionedTransactionPayload::SystemTransactionV1(SystemTransactionV1 { + instructions, + blobs, + .. + }) => (instructions.0, blobs.blobs), + }; - let blobs: Vec> = blobs.into_iter().map(|item| item.0).collect(); - (manifest_instructions, blobs) - } - Err(_) => { - // return original error - return Err(e.into()); - } + let blobs: Vec> = blobs.into_iter().map(|item| item.0).collect(); + (manifest_instructions, blobs) + } + Err(_) => { + // return original error + return Err(e.into()); } } - }; + } + }; validate_call_arguments_to_native_components(&manifest_instructions) .map_err(Error::InstructionSchemaValidationError)?; diff --git a/radix-sbor-derive/src/manifest_categorize.rs b/radix-sbor-derive/src/manifest_categorize.rs index c0a5b18c968..ec2e2bad81b 100644 --- a/radix-sbor-derive/src/manifest_categorize.rs +++ b/radix-sbor-derive/src/manifest_categorize.rs @@ -72,9 +72,9 @@ mod tests { fn get_length(&self) -> usize { match self { - Self::A { .. } => 1, - Self::B(_) => 1, - Self::C => 0, + Self::A { .. } => 1usize, + Self::B(_) => 1usize, + Self::C => 0usize, } } } diff --git a/radix-sbor-derive/src/manifest_decode.rs b/radix-sbor-derive/src/manifest_decode.rs index 6adbd694983..c2407b83218 100644 --- a/radix-sbor-derive/src/manifest_decode.rs +++ b/radix-sbor-derive/src/manifest_decode.rs @@ -37,7 +37,7 @@ mod tests { ) -> Result { use sbor::{self, Decode}; decoder.check_preloaded_value_kind(value_kind, sbor::ValueKind::Tuple)?; - decoder.read_and_check_size(0)?; + decoder.read_and_check_size(0usize)?; Ok(Self {}) } } @@ -72,17 +72,17 @@ mod tests { #[deny(unreachable_patterns)] match discriminator { 0u8 => { - decoder.read_and_check_size(1)?; + decoder.read_and_check_size(1usize)?; Ok(Self::A { named: decoder.decode::()?, }) }, 1u8 => { - decoder.read_and_check_size(1)?; + decoder.read_and_check_size(1usize)?; Ok(Self::B(decoder.decode::()?,)) }, 2u8 => { - decoder.read_and_check_size(0)?; + decoder.read_and_check_size(0usize)?; Ok(Self::C) }, _ => Err(sbor::DecodeError::UnknownDiscriminator(discriminator)) diff --git a/radix-sbor-derive/src/manifest_encode.rs b/radix-sbor-derive/src/manifest_encode.rs index 52967329e65..0c5708c1118 100644 --- a/radix-sbor-derive/src/manifest_encode.rs +++ b/radix-sbor-derive/src/manifest_encode.rs @@ -37,7 +37,7 @@ mod tests { #[inline] fn encode_body(&self, encoder: &mut E) -> Result<(), sbor::EncodeError> { use sbor::{self, Encode}; - encoder.write_size(0)?; + encoder.write_size(0usize)?; Ok(()) } } @@ -71,17 +71,17 @@ mod tests { match self { Self::A { named, .. } => { encoder.write_discriminator(0u8)?; - encoder.write_size(1)?; + encoder.write_size(1usize)?; encoder.encode(named)?; } Self::B(a0) => { encoder.write_discriminator(1u8)?; - encoder.write_size(1)?; + encoder.write_size(1usize)?; encoder.encode(a0)?; } Self::C => { encoder.write_discriminator(2u8)?; - encoder.write_size(0)?; + encoder.write_size(0usize)?; } } Ok(()) diff --git a/radix-sbor-derive/src/scrypto_categorize.rs b/radix-sbor-derive/src/scrypto_categorize.rs index d9605046aec..70ae7b0f2eb 100644 --- a/radix-sbor-derive/src/scrypto_categorize.rs +++ b/radix-sbor-derive/src/scrypto_categorize.rs @@ -72,9 +72,9 @@ mod tests { fn get_length(&self) -> usize { match self { - Self::A { .. } => 1, - Self::B(_) => 1, - Self::C => 0, + Self::A { .. } => 1usize, + Self::B(_) => 1usize, + Self::C => 0usize, } } } diff --git a/radix-sbor-derive/src/scrypto_decode.rs b/radix-sbor-derive/src/scrypto_decode.rs index b463564caa4..00caf692659 100644 --- a/radix-sbor-derive/src/scrypto_decode.rs +++ b/radix-sbor-derive/src/scrypto_decode.rs @@ -37,7 +37,7 @@ mod tests { ) -> Result { use sbor::{self, Decode}; decoder.check_preloaded_value_kind(value_kind, sbor::ValueKind::Tuple)?; - decoder.read_and_check_size(0)?; + decoder.read_and_check_size(0usize)?; Ok(Self {}) } } @@ -72,17 +72,17 @@ mod tests { #[deny(unreachable_patterns)] match discriminator { 0u8 => { - decoder.read_and_check_size(1)?; + decoder.read_and_check_size(1usize)?; Ok(Self::A { named: decoder.decode::()?, }) }, 1u8 => { - decoder.read_and_check_size(1)?; + decoder.read_and_check_size(1usize)?; Ok(Self::B(decoder.decode::()?,)) }, 2u8 => { - decoder.read_and_check_size(0)?; + decoder.read_and_check_size(0usize)?; Ok(Self::C) }, _ => Err(sbor::DecodeError::UnknownDiscriminator(discriminator)) diff --git a/radix-sbor-derive/src/scrypto_describe.rs b/radix-sbor-derive/src/scrypto_describe.rs index b0215006290..310a63d7829 100644 --- a/radix-sbor-derive/src/scrypto_describe.rs +++ b/radix-sbor-derive/src/scrypto_describe.rs @@ -108,9 +108,9 @@ mod tests { sbor::TypeData::enum_variants( "MyEnum", sbor :: rust :: prelude :: indexmap ! [ - 0u8 => sbor :: TypeData :: struct_with_named_fields ("A", sbor :: rust :: vec ! [("named", < T as sbor :: Describe < radix_common::data::scrypto::ScryptoCustomTypeKind >> :: TYPE_ID) ,] ,) , - 1u8 => sbor :: TypeData :: struct_with_unnamed_fields ("B", sbor :: rust :: vec ! [< String as sbor :: Describe < radix_common::data::scrypto::ScryptoCustomTypeKind >> :: TYPE_ID ,] ,) , - 2u8 => sbor :: TypeData :: struct_with_unit_fields ("C") , + 0u8 => { sbor :: TypeData :: struct_with_named_fields ("A", sbor :: rust :: vec ! [("named", < T as sbor :: Describe < radix_common::data::scrypto::ScryptoCustomTypeKind >> :: TYPE_ID) ,] ,) }, + 1u8 => { sbor :: TypeData :: struct_with_unnamed_fields ("B", sbor :: rust :: vec ! [< String as sbor :: Describe < radix_common::data::scrypto::ScryptoCustomTypeKind >> :: TYPE_ID ,] ,) }, + 2u8 => { sbor :: TypeData :: struct_with_unit_fields ("C") }, ], ) } diff --git a/radix-sbor-derive/src/scrypto_encode.rs b/radix-sbor-derive/src/scrypto_encode.rs index eeec7226c5b..d878c79bcc7 100644 --- a/radix-sbor-derive/src/scrypto_encode.rs +++ b/radix-sbor-derive/src/scrypto_encode.rs @@ -37,7 +37,7 @@ mod tests { #[inline] fn encode_body(&self, encoder: &mut E) -> Result<(), sbor::EncodeError> { use sbor::{self, Encode}; - encoder.write_size(0)?; + encoder.write_size(0usize)?; Ok(()) } } @@ -71,17 +71,17 @@ mod tests { match self { Self::A { named, .. } => { encoder.write_discriminator(0u8)?; - encoder.write_size(1)?; + encoder.write_size(1usize)?; encoder.encode(named)?; } Self::B(a0) => { encoder.write_discriminator(1u8)?; - encoder.write_size(1)?; + encoder.write_size(1usize)?; encoder.encode(a0)?; } Self::C => { encoder.write_discriminator(2u8)?; - encoder.write_size(0)?; + encoder.write_size(0usize)?; } } Ok(()) diff --git a/radix-transactions/src/model/versioned.rs b/radix-transactions/src/model/versioned.rs index a91900e52fe..21a3286bb76 100644 --- a/radix-transactions/src/model/versioned.rs +++ b/radix-transactions/src/model/versioned.rs @@ -30,36 +30,19 @@ const V1_PREVIEW_TRANSACTION: u8 = 6; const V1_LEDGER_TRANSACTION: u8 = 7; const V1_FLASH_TRANSACTION: u8 = 8; -// TODO - change this to use #[flatten] when REP-84 is out -/// An enum of a variety of different transaction payload types +/// An enum of a variety of different transaction payload types. /// This might see use in (eg) the Node's transaction parse API. /// These represent the different transaction types. #[derive(Clone, Debug, Eq, PartialEq, ManifestSbor)] pub enum VersionedTransactionPayload { - #[sbor(discriminator(V1_INTENT))] - IntentV1 { - header: TransactionHeaderV1, - instructions: InstructionsV1, - blobs: BlobsV1, - message: MessageV1, - }, - #[sbor(discriminator(V1_SIGNED_INTENT))] - SignedIntentV1 { - intent: IntentV1, - intent_signatures: IntentSignaturesV1, - }, - #[sbor(discriminator(V1_NOTARIZED_TRANSACTION))] - NotarizedTransactionV1 { - signed_intent: SignedIntentV1, - notary_signature: NotarySignatureV1, - }, - #[sbor(discriminator(V1_SYSTEM_TRANSACTION))] - SystemTransactionV1 { - instructions: InstructionsV1, - blobs: BlobsV1, - pre_allocated_addresses: Vec, - hash_for_execution: Hash, - }, + #[sbor(flatten, discriminator(V1_INTENT))] + IntentV1(IntentV1), + #[sbor(flatten, discriminator(V1_SIGNED_INTENT))] + SignedIntentV1(SignedIntentV1), + #[sbor(flatten, discriminator(V1_NOTARIZED_TRANSACTION))] + NotarizedTransactionV1(NotarizedTransactionV1), + #[sbor(flatten, discriminator(V1_SYSTEM_TRANSACTION))] + SystemTransactionV1(SystemTransactionV1), } #[cfg(test)] @@ -153,12 +136,7 @@ mod tests { manifest_decode::(&intent_payload_bytes).unwrap(); assert_eq!( intent_as_versioned, - VersionedTransactionPayload::IntentV1 { - header: header_v1, - instructions: instructions_v1, - blobs: blobs_v1, - message: message_v1, - } + VersionedTransactionPayload::IntentV1(intent_v1.clone()) ); let prepared_intent = @@ -212,10 +190,7 @@ mod tests { manifest_decode::(&signed_intent_payload_bytes).unwrap(); assert_eq!( signed_intent_as_versioned, - VersionedTransactionPayload::SignedIntentV1 { - intent: intent_v1, - intent_signatures: intent_signatures_v1, - } + VersionedTransactionPayload::SignedIntentV1(signed_intent_v1.clone()) ); let prepared_signed_intent = @@ -272,10 +247,7 @@ mod tests { .unwrap(); assert_eq!( notarized_transaction_as_versioned, - VersionedTransactionPayload::NotarizedTransactionV1 { - signed_intent: signed_intent_v1, - notary_signature: notary_signature_v1, - } + VersionedTransactionPayload::NotarizedTransactionV1(notarized_transaction_v1) ); let prepared_notarized_transaction = PreparedNotarizedTransactionV1::prepare_from_payload( @@ -364,12 +336,7 @@ mod tests { .unwrap(); assert_eq!( system_transaction_as_versioned, - VersionedTransactionPayload::SystemTransactionV1 { - instructions: instructions_v1, - blobs: blobs_v1, - pre_allocated_addresses: pre_allocated_addresses_v1, - hash_for_execution - } + VersionedTransactionPayload::SystemTransactionV1(system_transaction_v1) ); let prepared_system_transaction = diff --git a/sbor-derive-common/src/categorize.rs b/sbor-derive-common/src/categorize.rs index 399b4ff1b53..694628b39fd 100644 --- a/sbor-derive-common/src/categorize.rs +++ b/sbor-derive-common/src/categorize.rs @@ -52,11 +52,8 @@ fn handle_normal_categorize( let output = match data { Data::Struct(s) => { - let FieldsData { - unskipped_field_names, - .. - } = process_fields(&s.fields)?; - let field_count = unskipped_field_names.len(); + let fields_data = process_fields(&s.fields)?; + let unskipped_field_count = fields_data.unskipped_field_count(); quote! { impl #impl_generics sbor::Categorize <#sbor_cvk> for #ident #ty_generics #where_clause { #[inline] @@ -67,7 +64,7 @@ fn handle_normal_categorize( impl #impl_generics sbor::SborTuple <#sbor_cvk> for #ident #ty_generics #where_clause { fn get_length(&self) -> usize { - #field_count + #unskipped_field_count } } } @@ -80,28 +77,43 @@ fn handle_normal_categorize( .iter() .map(|source_variant| { match source_variant { - SourceVariantData::Reachable(VariantData { source_variant, discriminator, fields_data, .. }) => { - let v_id = &source_variant.ident; - let FieldsData { - unskipped_field_count, - empty_fields_unpacking, - .. - } = &fields_data; + SourceVariantData::Reachable(VariantData { + variant_name, + discriminator, + fields_handling: FieldsHandling::Standard(fields_data), + .. + }) => { + let unskipped_field_count = fields_data.unskipped_field_count(); + let empty_fields_unpacking = fields_data.empty_fields_unpacking(); ( - quote! { Self::#v_id #empty_fields_unpacking => #discriminator, }, - quote! { Self::#v_id #empty_fields_unpacking => #unskipped_field_count, }, + quote! { Self::#variant_name #empty_fields_unpacking => #discriminator, }, + quote! { Self::#variant_name #empty_fields_unpacking => #unskipped_field_count, }, ) }, - SourceVariantData::Unreachable(UnreachableVariantData { source_variant, fields_data, ..}) => { - let v_id = &source_variant.ident; - let FieldsData { - empty_fields_unpacking, - .. - } = &fields_data; - let panic_message = format!("Variant {} ignored as unreachable", v_id.to_string()); + SourceVariantData::Reachable(VariantData { + variant_name, + discriminator, + fields_handling: FieldsHandling::Flatten { + unique_field, + fields_data, + }, + .. + }) => { + let empty_fields_unpacking = fields_data.empty_fields_unpacking(); + let fields_unpacking = fields_data.fields_unpacking(); + let unskipped_field_type = unique_field.field_type(); + let unpacking_variable_name = unique_field.variable_name_from_unpacking(); ( - quote! { Self::#v_id #empty_fields_unpacking => panic!(#panic_message), }, - quote! { Self::#v_id #empty_fields_unpacking => panic!(#panic_message), }, + quote! { Self::#variant_name #empty_fields_unpacking => #discriminator, }, + quote! { Self::#variant_name #fields_unpacking => <#unskipped_field_type as SborTuple<#sbor_cvk>>::get_length(#unpacking_variable_name), }, + ) + }, + SourceVariantData::Unreachable(UnreachableVariantData { variant_name, fields_data, ..}) => { + let empty_fields_unpacking = fields_data.empty_fields_unpacking(); + let panic_message = format!("Variant {} ignored as unreachable", variant_name.to_string()); + ( + quote! { Self::#variant_name #empty_fields_unpacking => panic!(#panic_message), }, + quote! { Self::#variant_name #empty_fields_unpacking => panic!(#panic_message), }, ) }, } @@ -164,22 +176,18 @@ fn handle_transparent_categorize( let DeriveInput { data, .. } = &parsed; match data { Data::Struct(s) => { - let FieldsData { - unskipped_field_names, - unskipped_field_types, - .. - } = process_fields(&s.fields)?; - if unskipped_field_types.len() != 1 { - return Err(Error::new(Span::call_site(), "The transparent attribute is only supported for structs with a single unskipped field.")); - } - let field_type = &unskipped_field_types[0]; - let field_name = &unskipped_field_names[0]; + let single_field = process_fields(&s.fields)? + .unique_unskipped_field() + .map_err(|()| Error::new( + Span::call_site(), + "The transparent attribute is only supported for structs with a single unskipped field.", + ))?; handle_categorize_as( parsed, context_custom_value_kind, - field_type, - "e! { &self.#field_name } + single_field.field_type(), + &single_field.self_field_reference(), ) } Data::Enum(_) => { @@ -458,9 +466,9 @@ mod tests { fn get_length(&self) -> usize { match self { - Self::A => 0, - Self::B(_) => 1, - Self::C { .. } => 1, + Self::A => 0usize, + Self::B(_) => 1usize, + Self::C { .. } => 1usize, } } } diff --git a/sbor-derive-common/src/decode.rs b/sbor-derive-common/src/decode.rs index 50f8a683427..9fc39d8abc8 100644 --- a/sbor-derive-common/src/decode.rs +++ b/sbor-derive-common/src/decode.rs @@ -46,55 +46,23 @@ pub fn handle_transparent_decode( match data { Data::Struct(s) => { - let FieldsData { - unskipped_field_names, - unskipped_field_types, - skipped_field_names, - skipped_field_types, - .. - } = process_fields(&s.fields)?; - if unskipped_field_names.len() != 1 { - return Err(Error::new(Span::call_site(), "The transparent attribute is only supported for structs with a single unskipped field.")); - } - let field_name = &unskipped_field_names[0]; - let field_type = &unskipped_field_types[0]; - - let decode_content = match &s.fields { - syn::Fields::Named(_) => { - quote! { - Self { - #field_name: value, - #(#skipped_field_names: <#skipped_field_types>::default(),)* - } - } - } - syn::Fields::Unnamed(FieldsUnnamed { unnamed, .. }) => { - let mut field_values = Vec::::new(); - for field in unnamed { - if is_skipped(field)? { - let field_type = &field.ty; - field_values.push(parse_quote! {<#field_type>::default()}) - } else { - field_values.push(parse_quote! {value}) - } - } - quote! { - Self( - #(#field_values,)* - ) - } - } - syn::Fields::Unit => { - quote! { - Self {} - } - } - }; + let fields_data = process_fields(&s.fields)?; + let single_field = fields_data + .unique_unskipped_field() + .map_err(|()| Error::new( + Span::call_site(), + "The transparent attribute is only supported for structs with a single unskipped field.", + ))?; + + let decode_content = decode_unique_unskipped_field_from_value( + quote!{ Self }, + &fields_data, + )?; handle_decode_as( parsed, context_custom_value_kind, - field_type, + single_field.field_type(), &decode_content, ) } @@ -152,7 +120,8 @@ pub fn handle_normal_decode( let output = match data { Data::Struct(s) => { - let decode_fields_content = decode_fields_content(quote! { Self }, &s.fields)?; + let fields_data = process_fields(&s.fields)?; + let decode_fields_content = decode_fields_content(quote! { Self }, &fields_data)?; quote! { impl #impl_generics sbor::Decode <#custom_value_kind_generic, #decoder_generic> for #ident #ty_generics #where_clause { @@ -169,19 +138,39 @@ pub fn handle_normal_decode( let EnumVariantsData { sbor_variants, .. } = process_enum_variants(&attrs, &variants)?; let match_arms = sbor_variants .iter() - .map( - |VariantData { - source_variant, - discriminator_pattern, - .. - }| - -> Result<_> { - let v_id = &source_variant.ident; - let decode_fields_content = - decode_fields_content(quote! { Self::#v_id }, &source_variant.fields)?; + .map(|VariantData { + variant_name, + discriminator_pattern, + fields_handling, + .. + }| -> Result<_> { + let content = match fields_handling { + FieldsHandling::Standard(fields_data) => { + decode_fields_content( + quote! { Self::#variant_name }, + fields_data, + )? + }, + FieldsHandling::Flatten { unique_field, fields_data } => { + let field_type = unique_field.field_type(); + let construct_variant = decode_unique_unskipped_field_from_value( + quote! { Self::#variant_name }, + fields_data, + )?; + let tuple_assertion = output_flatten_type_is_sbor_tuple_assertion( + &custom_value_kind_generic, + field_type, + ); + quote! { + #tuple_assertion + let value = <#field_type as sbor::Decode<#custom_value_kind_generic, #decoder_generic>>::decode_body_with_value_kind(decoder, ValueKind::Tuple)?; + Ok(#construct_variant) + } + }, + }; Ok(quote! { #discriminator_pattern => { - #decode_fields_content + #content } }) }, @@ -214,48 +203,57 @@ pub fn handle_normal_decode( Ok(output) } -pub fn decode_fields_content( +fn decode_fields_content( self_constructor: TokenStream, - fields: &syn::Fields, + fields_data: &FieldsData, ) -> Result { - let FieldsData { - unskipped_field_names, - unskipped_field_types, - skipped_field_names, - skipped_field_types, - unskipped_field_count, - .. - } = process_fields(fields)?; - - Ok(match fields { - syn::Fields::Named(_) => { + let unskipped_field_count = fields_data.unskipped_field_count(); + + Ok(match fields_data { + FieldsData::Named(fields) => { + let assignments = fields.iter().map( + |NamedField { + name, + field_type, + is_skipped, + }| { + if *is_skipped { + quote! { #name: <#field_type>::default() } + } else { + quote! { #name: decoder.decode::<#field_type>()? } + } + }, + ); quote! { decoder.read_and_check_size(#unskipped_field_count)?; Ok(#self_constructor { - #(#unskipped_field_names: decoder.decode::<#unskipped_field_types>()?,)* - #(#skipped_field_names: <#skipped_field_types>::default(),)* + #(#assignments,)* }) } } - syn::Fields::Unnamed(FieldsUnnamed { unnamed, .. }) => { - let mut fields = Vec::::new(); - for f in unnamed { - let ty = &f.ty; - if is_skipped(f)? { - fields.push(parse_quote! {<#ty>::default()}) - } else { - fields.push(parse_quote! {decoder.decode::<#ty>()?}) - } - } + FieldsData::Unnamed(fields) => { + let values = fields.iter().map( + |UnnamedField { + field_type, + is_skipped, + .. + }| { + if *is_skipped { + quote! { <#field_type>::default() } + } else { + quote! { decoder.decode::<#field_type>()? } + } + }, + ); quote! { decoder.read_and_check_size(#unskipped_field_count)?; Ok(#self_constructor ( - #(#fields,)* + #(#values,)* )) } } - syn::Fields::Unit => { + FieldsData::Unit => { quote! { decoder.read_and_check_size(#unskipped_field_count)?; Ok(#self_constructor) @@ -264,6 +262,67 @@ pub fn decode_fields_content( }) } +fn decode_unique_unskipped_field_from_value( + self_constructor: TokenStream, + fields_data: &FieldsData, +) -> Result { + if fields_data.unique_unskipped_field().is_err() { + panic!("Should already have checked that there is only one unique unskipped field before calling this method"); + } + + let output = match &fields_data { + FieldsData::Named(fields) => { + let assignments = fields.iter().map( + |NamedField { + name, + field_type, + is_skipped, + }| { + if *is_skipped { + quote! { #name: <#field_type>::default() } + } else { + // Have already checked there's only one of these + quote! { #name: value } + } + }, + ); + quote! { + #self_constructor { + #(#assignments,)* + } + } + } + FieldsData::Unnamed(fields) => { + let field_values = fields.iter().map( + |UnnamedField { + field_type, + is_skipped, + .. + }| { + if *is_skipped { + quote! { <#field_type>::default() } + } else { + // Have already checked there's only one of these + quote! { value } + } + }, + ); + quote! { + #self_constructor( + #(#field_values,)* + ) + } + } + FieldsData::Unit => { + quote! { + #self_constructor + } + } + }; + + Ok(output) +} + #[cfg(test)] mod tests { use proc_macro2::TokenStream; @@ -288,7 +347,7 @@ mod tests { fn decode_body_with_value_kind(decoder: &mut D, value_kind: sbor::ValueKind) -> Result { use sbor::{self, Decode}; decoder.check_preloaded_value_kind(value_kind, sbor::ValueKind::Tuple)?; - decoder.read_and_check_size(1)?; + decoder.read_and_check_size(1usize)?; Ok(Self { a: decoder.decode::()?, }) @@ -315,7 +374,7 @@ mod tests { fn decode_body_with_value_kind(decoder: &mut D0, value_kind: sbor::ValueKind) -> Result { use sbor::{self, Decode}; decoder.check_preloaded_value_kind(value_kind, sbor::ValueKind::Tuple)?; - decoder.read_and_check_size(2)?; + decoder.read_and_check_size(2usize)?; Ok(Self { a: decoder.decode::()?, b: decoder.decode::()?, @@ -342,7 +401,7 @@ mod tests { fn decode_body_with_value_kind(decoder: &mut D, value_kind: sbor::ValueKind) -> Result { use sbor::{self, Decode}; decoder.check_preloaded_value_kind(value_kind, sbor::ValueKind::Tuple)?; - decoder.read_and_check_size(1)?; + decoder.read_and_check_size(1usize)?; Ok(Self { a: decoder.decode::()?, }) @@ -372,7 +431,7 @@ mod tests { fn decode_body_with_value_kind(decoder: &mut D, value_kind: sbor::ValueKind) -> Result { use sbor::{self, Decode}; decoder.check_preloaded_value_kind(value_kind, sbor::ValueKind::Tuple)?; - decoder.read_and_check_size(4)?; + decoder.read_and_check_size(4usize)?; Ok(Self { a: decoder.decode::<&'a u32>()?, b: decoder.decode::()?, @@ -402,15 +461,15 @@ mod tests { #[deny(unreachable_patterns)] match discriminator { 0u8 => { - decoder.read_and_check_size(0)?; + decoder.read_and_check_size(0usize)?; Ok(Self::A) }, 1u8 => { - decoder.read_and_check_size(1)?; + decoder.read_and_check_size(1usize)?; Ok(Self::B(decoder.decode::()?,)) }, 2u8 => { - decoder.read_and_check_size(1)?; + decoder.read_and_check_size(1usize)?; Ok(Self::C { x: decoder.decode::()?, }) @@ -436,7 +495,7 @@ mod tests { fn decode_body_with_value_kind(decoder: &mut D, value_kind: sbor::ValueKind) -> Result { use sbor::{self, Decode}; decoder.check_preloaded_value_kind(value_kind, sbor::ValueKind::Tuple)?; - decoder.read_and_check_size(0)?; + decoder.read_and_check_size(0usize)?; Ok(Self { a: ::default(), }) diff --git a/sbor-derive-common/src/describe.rs b/sbor-derive-common/src/describe.rs index 8a9e37ebddc..0c5c75e3b34 100644 --- a/sbor-derive-common/src/describe.rs +++ b/sbor-derive-common/src/describe.rs @@ -1,3 +1,4 @@ +use itertools::Itertools as _; use proc_macro2::{Span, TokenStream}; use quote::quote; use syn::*; @@ -48,21 +49,17 @@ fn handle_transparent_describe( let DeriveInput { data, .. } = &parsed; match &data { Data::Struct(s) => { - let FieldsData { - unskipped_field_types, - .. - } = process_fields(&s.fields)?; - - if unskipped_field_types.len() != 1 { - return Err(Error::new(Span::call_site(), "The transparent attribute is only supported for structs with a single unskipped field.")); - } - - let field_type = &unskipped_field_types[0]; + let single_field = process_fields(&s.fields)? + .unique_unskipped_field() + .map_err(|()| Error::new( + Span::call_site(), + "The transparent attribute is only supported for structs with a single unskipped field.", + ))?; handle_describe_as( parsed, context_custom_type_kind, - field_type, + single_field.field_type(), code_hash, ) } @@ -148,6 +145,12 @@ fn handle_describe_as( Ok(output) } +#[derive(PartialEq, Eq, Hash)] +enum TypeDependency { + Child(Type), + DescendentsOnly(Type), +} + fn handle_normal_describe( parsed: DeriveInput, code_hash: TokenStream, @@ -183,143 +186,79 @@ fn handle_normal_describe( ) }; - let output = match data { - Data::Struct(s) => match &s.fields { - syn::Fields::Named(FieldsNamed { .. }) => { - let FieldsData { - unskipped_field_types, - unskipped_field_name_strings, - .. - } = process_fields(&s.fields)?; - let unique_field_types: Vec<_> = get_unique_types(&unskipped_field_types); - quote! { - impl #impl_generics sbor::Describe <#custom_type_kind_generic> for #ident #ty_generics #where_clause { - const TYPE_ID: sbor::RustTypeId = #type_id; - - fn type_data() -> sbor::TypeData<#custom_type_kind_generic, sbor::RustTypeId> { - sbor::TypeData::struct_with_named_fields( - #type_name, - sbor::rust::vec![ - #((#unskipped_field_name_strings, <#unskipped_field_types as sbor::Describe<#custom_type_kind_generic>>::TYPE_ID),)* - ], - ) - } + let mut type_dependencies = vec![]; - fn add_all_dependencies(aggregator: &mut sbor::TypeAggregator<#custom_type_kind_generic>) { - #(aggregator.add_child_type_and_descendents::<#unique_field_types>();)* - } - } - } - } - syn::Fields::Unnamed(FieldsUnnamed { .. }) => { - let FieldsData { - unskipped_field_types, - .. - } = process_fields(&s.fields)?; - let unique_field_types: Vec<_> = get_unique_types(&unskipped_field_types); - - quote! { - impl #impl_generics sbor::Describe <#custom_type_kind_generic> for #ident #ty_generics #where_clause { - const TYPE_ID: sbor::RustTypeId = #type_id; - - fn type_data() -> sbor::TypeData<#custom_type_kind_generic, sbor::RustTypeId> { - sbor::TypeData::struct_with_unnamed_fields( - #type_name, - sbor::rust::vec![ - #(<#unskipped_field_types as sbor::Describe<#custom_type_kind_generic>>::TYPE_ID,)* - ], - ) - } + let type_data = match data { + Data::Struct(s) => { + let fields_data = process_fields(&s.fields)?; - fn add_all_dependencies(aggregator: &mut sbor::TypeAggregator<#custom_type_kind_generic>) { - #(aggregator.add_child_type_and_descendents::<#unique_field_types>();)* - } - } - } + for field_type in fields_data.unskipped_field_types() { + type_dependencies.push(TypeDependency::Child(field_type)); } - syn::Fields::Unit => { - quote! { - impl #impl_generics sbor::Describe <#custom_type_kind_generic> for #ident #ty_generics #where_clause { - const TYPE_ID: sbor::RustTypeId = #type_id; - fn type_data() -> sbor::TypeData<#custom_type_kind_generic, sbor::RustTypeId> { - sbor::TypeData::struct_with_unit_fields(#type_name) - } - } - } - } - }, + fields_data_to_type_data(&custom_type_kind_generic, &type_name, &fields_data)? + } Data::Enum(DataEnum { variants, .. }) => { let EnumVariantsData { sbor_variants, .. } = process_enum_variants(&attrs, &variants)?; - let mut all_field_types = Vec::new(); - let match_arms = sbor_variants .iter() - .map(|VariantData { discriminator, source_variant, fields_data, .. }| { - let variant_name_str = source_variant.ident.to_string(); - - let FieldsData { - unskipped_field_types, - unskipped_field_name_strings, - .. - } = fields_data; - - all_field_types.extend_from_slice(&unskipped_field_types); + .map(|VariantData { variant_name, fields_handling, discriminator, .. }| { + let variant_name_str = ident_to_lit_str(variant_name); - let variant_type_data = match &source_variant.fields { - Fields::Named(FieldsNamed { .. }) => { - quote! { - sbor::TypeData::struct_with_named_fields( - #variant_name_str, - sbor::rust::vec![ - #((#unskipped_field_name_strings, <#unskipped_field_types as sbor::Describe<#custom_type_kind_generic>>::TYPE_ID),)* - ], - ) - } - } - Fields::Unnamed(FieldsUnnamed { .. }) => { - quote! { - sbor::TypeData::struct_with_unnamed_fields( - #variant_name_str, - sbor::rust::vec![ - #(<#unskipped_field_types as sbor::Describe<#custom_type_kind_generic>>::TYPE_ID,)* - ], - ) + let variant_type_data = match fields_handling { + FieldsHandling::Standard(fields_data) => { + for field_type in fields_data.unskipped_field_types() { + type_dependencies.push(TypeDependency::Child(field_type)); } + fields_data_to_type_data( + &custom_type_kind_generic, + &variant_name_str, + &fields_data, + )? } - Fields::Unit => { + FieldsHandling::Flatten { unique_field, .. } => { + let flattened_type = unique_field.field_type(); + + // We need to include the flattened type's descendents, + // but not the flatenned type itself (we're taking its type details) + // and mutating them effectively + type_dependencies.push(TypeDependency::DescendentsOnly(flattened_type.clone())); + quote! { - sbor::TypeData::struct_with_unit_fields(#variant_name_str) + let mut flattened_type_data = <#flattened_type as sbor::Describe<#custom_type_kind_generic>>::type_data(); + // Flatten is only valid if the child type is a Tuple, so we must + // double-check that constraint. + // We can't do an SborTuple assertion here because we don't have a + // specific X = custom value kind to check on (e.g. SborSchema is shared + // by both ManifestSbor and ScryptoSbor). This assertion will almost certainly + // be checked by an Encode/Decode/Categorize implementation, but on the off-chance + // if isn't, let's check the type kind is correct in the type itself at + // Describe run time. + let sbor::schema::TypeKind::Tuple { .. } = &flattened_type_data.kind else { + panic!("The flatten attribute cannot be used with a non-tuple child"); + }; + // We rename the tuple type to be the enum variant name, + // and this becomes the enum variant type data + flattened_type_data.with_name(Some(sbor::rust::prelude::Cow::Borrowed(#variant_name_str))) } } }; - Ok(Some(quote! { - #discriminator => #variant_type_data, - })) + + Ok(quote! { + #discriminator => { #variant_type_data }, + }) }) .collect::>>()?; - let unique_field_types = get_unique_types(&all_field_types); - quote! { - impl #impl_generics sbor::Describe <#custom_type_kind_generic> for #ident #ty_generics #where_clause { - const TYPE_ID: sbor::RustTypeId = #type_id; - - fn type_data() -> sbor::TypeData<#custom_type_kind_generic, sbor::RustTypeId> { - use sbor::rust::borrow::ToOwned; - sbor::TypeData::enum_variants( - #type_name, - sbor::rust::prelude::indexmap![ - #(#match_arms)* - ], - ) - } - - fn add_all_dependencies(aggregator: &mut sbor::TypeAggregator<#custom_type_kind_generic>) { - #(aggregator.add_child_type_and_descendents::<#unique_field_types>();)* - } - } + use sbor::rust::borrow::ToOwned; + sbor::TypeData::enum_variants( + #type_name, + sbor::rust::prelude::indexmap![ + #(#match_arms)* + ], + ) } } Data::Union(_) => { @@ -327,7 +266,70 @@ fn handle_normal_describe( } }; - Ok(output) + let dependencies = + type_dependencies + .iter() + .unique() + .map(|type_dependency| match type_dependency { + TypeDependency::Child(child_type) => quote! { + aggregator.add_child_type_and_descendents::<#child_type>(); + }, + TypeDependency::DescendentsOnly(descendent_only_type) => quote! { + aggregator.add_schema_descendents::<#descendent_only_type>(); + }, + }); + + Ok(quote! { + impl #impl_generics sbor::Describe <#custom_type_kind_generic> for #ident #ty_generics #where_clause { + const TYPE_ID: sbor::RustTypeId = #type_id; + + fn type_data() -> sbor::TypeData<#custom_type_kind_generic, sbor::RustTypeId> { + #type_data + } + + fn add_all_dependencies(aggregator: &mut sbor::TypeAggregator<#custom_type_kind_generic>) { + #(#dependencies)* + } + } + }) +} + +fn fields_data_to_type_data( + custom_type_kind_generic: &Path, + type_name: &LitStr, + fields_data: &FieldsData, +) -> Result { + let unskipped_field_types = fields_data.unskipped_field_types(); + + let type_data = match fields_data { + FieldsData::Named(fields) => { + let unskipped_field_name_strings = fields.unskipped_field_name_strings(); + quote! { + sbor::TypeData::struct_with_named_fields( + #type_name, + sbor::rust::vec![ + #((#unskipped_field_name_strings, <#unskipped_field_types as sbor::Describe<#custom_type_kind_generic>>::TYPE_ID),)* + ], + ) + } + } + FieldsData::Unnamed(_) => { + quote! { + sbor::TypeData::struct_with_unnamed_fields( + #type_name, + sbor::rust::vec![ + #(<#unskipped_field_types as sbor::Describe<#custom_type_kind_generic>>::TYPE_ID,)* + ], + ) + } + } + FieldsData::Unit => { + quote! { + sbor::TypeData::struct_with_unit_fields(#type_name) + } + } + }; + Ok(type_data) } pub fn validate_type_name(type_name: &LitStr) -> Result<()> { @@ -533,6 +535,8 @@ mod tests { fn type_data() -> sbor::TypeData { sbor::TypeData::struct_with_unit_fields("Test") } + + fn add_all_dependencies(aggregator: &mut sbor::TypeAggregator) { } } }, ); @@ -564,20 +568,20 @@ mod tests { sbor::TypeData::enum_variants( "Test", sbor::rust::prelude::indexmap![ - 0u8 => sbor::TypeData::struct_with_unit_fields("A"), - 1u8 => sbor::TypeData::struct_with_unnamed_fields( + 0u8 => { sbor::TypeData::struct_with_unit_fields("A") }, + 1u8 => { sbor::TypeData::struct_with_unnamed_fields( "B", sbor::rust::vec![ >::TYPE_ID, as sbor::Describe>::TYPE_ID, ], - ), - 2u8 => sbor::TypeData::struct_with_named_fields( + ) }, + 2u8 => { sbor::TypeData::struct_with_named_fields( "C", sbor::rust::vec![ ("x", <[u8; 5] as sbor::Describe>::TYPE_ID), ], - ), + ) }, ], ) } diff --git a/sbor-derive-common/src/encode.rs b/sbor-derive-common/src/encode.rs index 1e8bf679693..077d73c316e 100644 --- a/sbor-derive-common/src/encode.rs +++ b/sbor-derive-common/src/encode.rs @@ -42,21 +42,17 @@ pub fn handle_transparent_encode( ) -> Result { let output = match &parsed.data { Data::Struct(s) => { - let FieldsData { - unskipped_field_types, - unskipped_field_names, - .. - } = process_fields(&s.fields)?; - if unskipped_field_types.len() != 1 { - return Err(Error::new(Span::call_site(), "The transparent attribute is only supported for structs with a single unskipped field.")); - } - let field_type = &unskipped_field_types[0]; - let field_name = &unskipped_field_names[0]; + let single_field = process_fields(&s.fields)? + .unique_unskipped_field() + .map_err(|()| Error::new( + Span::call_site(), + "The transparent attribute is only supported for structs with a single unskipped field.", + ))?; handle_encode_as( parsed, context_custom_value_kind, - &field_type, - "e! { &self.#field_name }, + single_field.field_type(), + &single_field.self_field_reference(), )? } Data::Enum(_) => { @@ -125,11 +121,9 @@ pub fn handle_normal_encode( let output = match data { Data::Struct(s) => { - let FieldsData { - unskipped_field_names, - unskipped_field_count, - .. - } = process_fields(&s.fields)?; + let fields_data = process_fields(&s.fields)?; + let unskipped_field_count = fields_data.unskipped_field_count(); + let unskipped_self_field_references = fields_data.unskipped_self_field_references(); quote! { impl #impl_generics sbor::Encode <#custom_value_kind_generic, #encoder_generic> for #ident #ty_generics #where_clause { #[inline] @@ -141,7 +135,7 @@ pub fn handle_normal_encode( fn encode_body(&self, encoder: &mut #encoder_generic) -> Result<(), sbor::EncodeError> { use sbor::{self, Encode}; encoder.write_size(#unskipped_field_count)?; - #(encoder.encode(&self.#unskipped_field_names)?;)* + #(encoder.encode(#unskipped_self_field_references)?;)* Ok(()) } } @@ -156,40 +150,62 @@ pub fn handle_normal_encode( .map(|source_variant| { Ok(match source_variant { SourceVariantData::Reachable(VariantData { - source_variant, + variant_name, discriminator, - fields_data, + fields_handling: FieldsHandling::Standard(fields_data), .. }) => { - let v_id = &source_variant.ident; - let FieldsData { - unskipped_field_count, - fields_unpacking, - unskipped_unpacked_field_names, - .. - } = fields_data; + let unskipped_field_count = fields_data.unskipped_field_count(); + let fields_unpacking = fields_data.fields_unpacking(); + let unskipped_unpacking_variable_names = fields_data.unskipped_unpacking_variable_names(); quote! { - Self::#v_id #fields_unpacking => { + Self::#variant_name #fields_unpacking => { encoder.write_discriminator(#discriminator)?; encoder.write_size(#unskipped_field_count)?; - #(encoder.encode(#unskipped_unpacked_field_names)?;)* + #(encoder.encode(#unskipped_unpacking_variable_names)?;)* + } + } + } + SourceVariantData::Reachable(VariantData { + variant_name, + discriminator, + fields_handling: FieldsHandling::Flatten { unique_field, fields_data, }, + .. + }) => { + let fields_unpacking = fields_data.fields_unpacking(); + let field_type = unique_field.field_type(); + let unpacking_field_name = unique_field.variable_name_from_unpacking(); + let tuple_assertion = output_flatten_type_is_sbor_tuple_assertion( + &custom_value_kind_generic, + field_type, + ); + quote! { + Self::#variant_name #fields_unpacking => { + // Flatten is only valid if the single child type is an SBOR tuple, so do a + // zero-cost assertion on this so the user gets a good error message if they + // misuse this. + #tuple_assertion + // We make use of the fact that an enum body encodes as (discriminator, fields_count, ..fields) + // And a tuple body encodes as (fields_count, ..fields) + // So we can flatten by encoding the discriminator and then running `encode_body` on the child tuple + encoder.write_discriminator(#discriminator)?; + <#field_type as sbor::Encode <#custom_value_kind_generic, #encoder_generic>>::encode_body( + #unpacking_field_name, + encoder + )?; } } } SourceVariantData::Unreachable(UnreachableVariantData { - source_variant, + variant_name, fields_data, .. }) => { - let v_id = &source_variant.ident; - let FieldsData { - empty_fields_unpacking, - .. - } = &fields_data; + let empty_fields_unpacking = fields_data.empty_fields_unpacking(); let panic_message = - format!("Variant {} ignored as unreachable", v_id.to_string()); + format!("Variant {} ignored as unreachable", variant_name.to_string()); quote! { - Self::#v_id #empty_fields_unpacking => panic!(#panic_message), + Self::#variant_name #empty_fields_unpacking => panic!(#panic_message), } } }) @@ -262,7 +278,7 @@ mod tests { #[inline] fn encode_body(&self, encoder: &mut E) -> Result<(), sbor::EncodeError> { use sbor::{self, Encode}; - encoder.write_size(1)?; + encoder.write_size(1usize)?; encoder.encode(&self.a)?; Ok(()) } @@ -291,16 +307,16 @@ mod tests { match self { Self::A => { encoder.write_discriminator(0u8)?; - encoder.write_size(0)?; + encoder.write_size(0usize)?; } Self::B(a0) => { encoder.write_discriminator(1u8)?; - encoder.write_size(1)?; + encoder.write_size(1usize)?; encoder.encode(a0)?; } Self::C { x, .. } => { encoder.write_discriminator(2u8)?; - encoder.write_size(1)?; + encoder.write_size(1usize)?; encoder.encode(x)?; } } @@ -328,7 +344,7 @@ mod tests { #[inline] fn encode_body(&self, encoder: &mut E) -> Result<(), sbor::EncodeError> { use sbor::{self, Encode}; - encoder.write_size(0)?; + encoder.write_size(0usize)?; Ok(()) } } @@ -357,7 +373,7 @@ mod tests { #[inline] fn encode_body(&self, encoder: &mut E0) -> Result<(), sbor::EncodeError> { use sbor::{self, Encode}; - encoder.write_size(2)?; + encoder.write_size(2usize)?; encoder.encode(&self.a)?; encoder.encode(&self.b)?; Ok(()) @@ -387,7 +403,7 @@ mod tests { #[inline] fn encode_body(&self, encoder: &mut E) -> Result<(), sbor::EncodeError> { use sbor::{self, Encode}; - encoder.write_size(0)?; + encoder.write_size(0usize)?; Ok(()) } } @@ -415,7 +431,7 @@ mod tests { #[inline] fn encode_body(&self, encoder: &mut E) -> Result<(), sbor::EncodeError> { use sbor::{self, Encode}; - encoder.write_size(0)?; + encoder.write_size(0usize)?; Ok(()) } } diff --git a/sbor-derive-common/src/utils.rs b/sbor-derive-common/src/utils.rs index c59bd6187ae..16a1626e474 100644 --- a/sbor-derive-common/src/utils.rs +++ b/sbor-derive-common/src/utils.rs @@ -269,16 +269,25 @@ pub(crate) enum SourceVariantData { #[derive(Clone)] pub(crate) struct UnreachableVariantData { - pub source_variant: Variant, + pub variant_name: Ident, pub fields_data: FieldsData, } #[derive(Clone)] pub(crate) struct VariantData { - pub source_variant: Variant, + pub variant_name: Ident, pub discriminator: Expr, pub discriminator_pattern: Pat, - pub fields_data: FieldsData, + pub fields_handling: FieldsHandling, +} + +#[derive(Clone)] +pub(crate) enum FieldsHandling { + Standard(FieldsData), + Flatten { + unique_field: SingleField, + fields_data: FieldsData, + }, } pub(crate) struct EnumVariantsData { @@ -310,13 +319,31 @@ pub(crate) fn process_enum_variants( .map(|(i, variant)| -> Result<_> { let mut variant_attributes = extract_wrapped_typed_attributes(&variant.attrs, "sbor")?; let fields_data = process_fields(&variant.fields)?; - let source_variant = variant.clone(); + let variant_name = variant.ident.clone(); if let Some(_) = variant_attributes.remove("unreachable") { return Ok(SourceVariantData::Unreachable(UnreachableVariantData { + variant_name, fields_data, - source_variant, })); } + let fields_handling = match variant_attributes.remove("flatten") { + Some(AttributeValue::None(span)) => { + let unique_field = fields_data.unique_unskipped_field() + .map_err(|()| Error::new( + span, + "The flatten attribute can only be used by an enum variant with exactly one unskipped field, and that child must be an SBOR tuple.", + ))?; + FieldsHandling::Flatten { + unique_field, + fields_data, + } + } + Some(unsupported_value) => return Err(Error::new( + unsupported_value.span(), + "The flatten attribute can't have any associated value.", + )), + None => FieldsHandling::Standard(fields_data), + }; reachable_variants_count += 1; let discriminator = resolve_discriminator(use_repr_discriminators, i, variant, variant_attributes)?; @@ -326,10 +353,10 @@ pub(crate) fn process_enum_variants( explicit_discriminators_count += 1; }; Ok(SourceVariantData::Reachable(VariantData { + variant_name, discriminator: discriminator.expression, discriminator_pattern: discriminator.pattern, - source_variant, - fields_data, + fields_handling, })) }) .collect::>()?; @@ -487,6 +514,10 @@ fn parse_pattern_from_literal(literal: &Lit) -> Option { } } +pub fn ident_to_lit_str(ident: &Ident) -> LitStr { + LitStr::new(&ident.to_string(), ident.span()) +} + pub fn get_sbor_attribute_string_value( attributes: &[Attribute], attribute_name: &str, @@ -600,6 +631,18 @@ pub fn get_generic_types(generics: &Generics) -> Vec { .collect() } +pub fn output_flatten_type_is_sbor_tuple_assertion( + custom_value_kind_generic: &Path, + type_to_assert: &Type, +) -> TokenStream { + // We give it a specific name because it's this name that appears on error messages + // if the assertion fails + quote! { + fn assert_flattened_type_is_sbor_tuple, XToAssertWith: sbor::CustomValueKind>() {} + assert_flattened_type_is_sbor_tuple::<#type_to_assert, #custom_value_kind_generic>(); + } +} + pub fn parse_str_with_span(source_string: &str, span: Span) -> Result { // https://github.com/dtolnay/syn/issues/559 LitStr::new(source_string, span).parse() @@ -732,123 +775,271 @@ pub fn get_hash_of_code(input: &TokenStream) -> [u8; 20] { const_sha1::sha1(input.to_string().as_bytes()).as_bytes() } -pub fn get_unique_types<'a>(types: &[syn::Type]) -> Vec { - types.iter().unique().cloned().collect() -} - #[derive(Clone)] -pub(crate) struct FieldsData { - pub unskipped_field_names: Vec, - pub unskipped_field_name_strings: Vec, - pub unskipped_field_types: Vec, - pub skipped_field_names: Vec, - pub skipped_field_types: Vec, - pub fields_unpacking: TokenStream, - pub empty_fields_unpacking: TokenStream, - pub unskipped_unpacked_field_names: Vec, - pub unskipped_field_count: Index, +pub(crate) enum FieldsData { + Named(NamedFieldsData), + Unnamed(UnnamedFieldsData), + Unit, } -pub(crate) fn process_fields(fields: &syn::Fields) -> Result { - Ok(match fields { - Fields::Named(fields) => { - let mut unskipped_field_names = Vec::new(); - let mut unskipped_field_name_strings = Vec::new(); - let mut unskipped_field_types = Vec::new(); - let mut skipped_field_names = Vec::new(); - let mut skipped_field_types = Vec::new(); - for f in fields.named.iter() { - let ident = &f.ident; - if !is_skipped(f)? { - unskipped_field_names.push(quote! { #ident }); - unskipped_field_name_strings - .push(ident.as_ref().map(|i| i.to_string()).unwrap_or_default()); - unskipped_field_types.push(f.ty.clone()); +impl FieldsData { + pub fn unskipped(&self) -> Box> + '_> { + match self { + Self::Named(fields) => Box::new(fields.unskipped().map(NamedField::as_box_dyn)), + Self::Unnamed(fields) => Box::new(fields.unskipped().map(UnnamedField::as_box_dyn)), + Self::Unit => Box::new(std::iter::empty()), + } + } + + pub fn unique_unskipped_field(&self) -> core::result::Result { + match self { + Self::Named(fields) => { + if let Some((field,)) = fields.unskipped().collect_tuple() { + Ok(SingleField::NamedField(field.clone())) } else { - skipped_field_names.push(quote! { #ident }); - skipped_field_types.push(f.ty.clone()); + Err(()) } } + Self::Unnamed(fields) => { + if let Some((field,)) = fields.unskipped().collect_tuple() { + Ok(SingleField::UnnamedField(field.clone())) + } else { + Err(()) + } + } + Self::Unit => Err(()), + } + } - let fields_unpacking = quote! { - {#(#unskipped_field_names,)* ..} - }; - let empty_fields_unpacking = quote! { + pub fn unskipped_self_field_references(&self) -> Vec { + self.unskipped().map(|f| f.self_field_reference()).collect() + } + + pub fn unskipped_field_types(&self) -> Vec { + self.unskipped().map(|f| f.field_type().clone()).collect() + } + + pub fn unskipped_field_count(&self) -> usize { + self.unskipped().count() + } + + pub fn empty_fields_unpacking(&self) -> TokenStream { + match self { + Self::Named(_) => quote! { { .. } - }; - let unskipped_unpacked_field_names = unskipped_field_names.clone(); - - let unskipped_field_count = Index::from(unskipped_field_names.len()); - - FieldsData { - unskipped_field_names, - unskipped_field_name_strings, - unskipped_field_types, - skipped_field_names, - skipped_field_types, - fields_unpacking, - empty_fields_unpacking, - unskipped_unpacked_field_names, - unskipped_field_count, + }, + Self::Unnamed(UnnamedFieldsData(fields)) => { + let empty_idents = fields.iter().map(|_| format_ident!("_")); + quote! { + (#(#empty_idents),*) + } } + Self::Unit => quote! {}, } - Fields::Unnamed(fields) => { - let mut unskipped_indices = Vec::new(); - let mut unskipped_field_name_strings = Vec::new(); - let mut unskipped_field_types = Vec::new(); - let mut unskipped_unpacked_field_names = Vec::new(); - let mut skipped_indices = Vec::new(); - let mut skipped_field_types = Vec::new(); - let mut unpacking_idents = Vec::new(); - let mut empty_idents = Vec::new(); - for (i, f) in fields.unnamed.iter().enumerate() { - let index = Index::from(i); - if !is_skipped(f)? { - unskipped_indices.push(quote! { #index }); - unskipped_field_name_strings.push(i.to_string()); - unskipped_field_types.push(f.ty.clone()); - let unpacked_name_ident = format_ident!("a{}", i); - unskipped_unpacked_field_names.push(quote! { #unpacked_name_ident }); - unpacking_idents.push(unpacked_name_ident); - } else { - skipped_indices.push(quote! { #index }); - skipped_field_types.push(f.ty.clone()); - unpacking_idents.push(format_ident!("_")); + } + + pub fn fields_unpacking(&self) -> TokenStream { + match self { + Self::Named(fields) => { + let field_names = fields.unskipped_field_names(); + quote! { + { #(#field_names,)* ..} } - empty_idents.push(format_ident!("_")); } - let fields_unpacking = quote! { - (#(#unpacking_idents),*) - }; - let empty_fields_unpacking = quote! { - (#(#empty_idents),*) - }; - - let unskipped_field_count = Index::from(unskipped_indices.len()); - - FieldsData { - unskipped_field_names: unskipped_indices, - unskipped_field_name_strings, - unskipped_field_types, - skipped_field_names: skipped_indices, - skipped_field_types, - fields_unpacking, - empty_fields_unpacking, - unskipped_unpacked_field_names, - unskipped_field_count, + Self::Unnamed(UnnamedFieldsData(fields)) => { + let variable_names = fields + .iter() + .map(|field| &field.variable_name_from_unpacking); + quote! { + (#(#variable_names),*) + } } + Self::Unit => quote! {}, + } + } + + pub fn unskipped_unpacking_variable_names(&self) -> Vec { + match self { + Self::Named(fields) => fields.unskipped_field_names(), + Self::Unnamed(fields) => fields + .unskipped() + .map(|field| field.variable_name_from_unpacking.clone()) + .collect(), + Self::Unit => vec![], + } + } +} + +pub(crate) trait FieldReference { + fn self_field_reference(&self) -> TokenStream; + fn field_type(&self) -> &Type; + fn variable_name_from_unpacking(&self) -> &Ident; +} + +#[derive(Clone)] +pub(crate) enum SingleField { + NamedField(NamedField), + UnnamedField(UnnamedField), +} + +impl FieldReference for SingleField { + fn self_field_reference(&self) -> TokenStream { + match self { + Self::NamedField(field) => field.self_field_reference(), + Self::UnnamedField(field) => field.self_field_reference(), + } + } + + fn field_type(&self) -> &Type { + match self { + Self::NamedField(field) => field.field_type(), + Self::UnnamedField(field) => field.field_type(), + } + } + + fn variable_name_from_unpacking(&self) -> &Ident { + match self { + Self::NamedField(field) => field.variable_name_from_unpacking(), + Self::UnnamedField(field) => field.variable_name_from_unpacking(), + } + } +} + +#[derive(Clone)] +pub(crate) struct NamedFieldsData(Vec); + +impl NamedFieldsData { + pub fn iter(&self) -> impl Iterator { + self.0.iter() + } + + pub fn unskipped(&self) -> impl Iterator { + self.0.iter().filter(|f| !f.is_skipped) + } + + pub fn unskipped_field_names(&self) -> Vec { + self.unskipped().map(|f| f.name.clone()).collect() + } + + pub fn unskipped_field_name_strings(&self) -> Vec { + self.unskipped().map(|f| f.name.to_string()).collect() + } +} + +#[derive(Clone)] +pub(crate) struct UnnamedFieldsData(Vec); + +impl UnnamedFieldsData { + pub fn iter(&self) -> impl Iterator { + self.0.iter() + } + + pub fn unskipped(&self) -> impl Iterator { + self.0.iter().filter(|f| !f.is_skipped) + } +} + +#[derive(Clone)] +pub(crate) struct NamedField { + pub name: Ident, + pub field_type: Type, + pub is_skipped: bool, +} + +impl NamedField { + fn as_box_dyn(&self) -> Box { + Box::new(self.clone()) + } +} + +impl FieldReference for NamedField { + fn self_field_reference(&self) -> TokenStream { + let name = &self.name; + quote! { &self.#name } + } + + fn field_type(&self) -> &Type { + &self.field_type + } + + fn variable_name_from_unpacking(&self) -> &Ident { + &self.name + } +} + +#[derive(Clone)] +pub(crate) struct UnnamedField { + pub index: Index, + pub variable_name_from_unpacking: Ident, + pub field_type: Type, + pub is_skipped: bool, +} + +impl UnnamedField { + fn as_box_dyn(&self) -> Box { + Box::new(self.clone()) + } +} + +impl FieldReference for UnnamedField { + fn self_field_reference(&self) -> TokenStream { + let index = &self.index; + quote! { &self.#index } + } + + fn field_type(&self) -> &Type { + &self.field_type + } + + fn variable_name_from_unpacking(&self) -> &Ident { + &self.variable_name_from_unpacking + } +} + +pub(crate) fn process_fields(fields: &syn::Fields) -> Result { + Ok(match fields { + Fields::Named(fields) => { + let fields = fields + .named + .iter() + .map(|f| -> Result<_> { + let ident = f.ident.as_ref().unwrap().clone(); + let is_skipped = is_skipped(f)?; + Ok(NamedField { + name: ident, + field_type: f.ty.clone(), + is_skipped, + }) + }) + .collect::>()?; + + FieldsData::Named(NamedFieldsData(fields)) + } + Fields::Unnamed(fields) => { + let fields = fields + .unnamed + .iter() + .enumerate() + .map(|(i, f)| -> Result<_> { + let index = Index::from(i); + let is_skipped = is_skipped(f)?; + let unpacked_variable_name = if is_skipped { + format_ident!("_") + } else { + format_ident!("a{}", i) + }; + Ok(UnnamedField { + index, + field_type: f.ty.clone(), + is_skipped, + variable_name_from_unpacking: unpacked_variable_name, + }) + }) + .collect::>()?; + + FieldsData::Unnamed(UnnamedFieldsData(fields)) } - Fields::Unit => FieldsData { - unskipped_field_names: vec![], - unskipped_field_name_strings: vec![], - unskipped_field_types: vec![], - skipped_field_names: vec![], - skipped_field_types: vec![], - fields_unpacking: quote! {}, - empty_fields_unpacking: quote! {}, - unskipped_unpacked_field_names: vec![], - unskipped_field_count: Index::from(0), - }, + Fields::Unit => FieldsData::Unit, }) } diff --git a/sbor-tests/tests/enum.rs b/sbor-tests/tests/enum.rs index 4c414612dc0..f160cb414d9 100644 --- a/sbor-tests/tests/enum.rs +++ b/sbor-tests/tests/enum.rs @@ -75,8 +75,52 @@ pub enum Mixed { I = 0b11011, } +#[derive(Debug, PartialEq, Eq, Sbor)] +enum FlattenEnum { + #[sbor(flatten)] + A { + #[sbor(skip)] + skipped: u8, + y: (u32, MyOtherType), + }, + #[sbor(flatten)] + B(#[sbor(skip)] u8, (u32,)), + #[sbor(flatten)] + C(MyInnerStruct), + D, + E(MyOtherTypeTwo), +} + +#[derive(Debug, PartialEq, Eq, Sbor)] +struct MyInnerStruct { + hello: String, + world: MyInnerInnerType, // This checks that we properly capture descendents in Describe +} + +#[derive(Debug, PartialEq, Eq, Sbor)] +struct MyOtherType(u8); + +#[derive(Debug, PartialEq, Eq, Sbor)] +struct MyInnerInnerType(u8); + +#[derive(Debug, PartialEq, Eq, Sbor)] +struct MyOtherTypeTwo(u8); + +#[derive(Debug, PartialEq, Eq, Sbor)] +#[sbor(type_name = "FlattenEnum")] +enum FlattenedEnum { + A(u32, MyOtherType), + B(u32), + C { + hello: String, + world: MyInnerInnerType, + }, + D, + E(MyOtherTypeTwo), +} + #[test] -fn can_encode_and_decode() { +fn test_encode_decode_and_schemas() { check_encode_decode_schema(&Abc::Variant1); check_encode_decode_schema(&Abc::Variant2); check_encode_decode_schema(&AbcV2::Variant1); @@ -92,8 +136,43 @@ fn can_encode_and_decode() { check_encode_decode_schema(&Mixed::G); check_encode_decode_schema(&Mixed::H); check_encode_decode_schema(&Mixed::I); + check_encode_decode_schema(&FlattenEnum::A { + skipped: 0, + y: (1, MyOtherType(5)), + }); + check_encode_identically( + &FlattenEnum::A { + skipped: 0, + y: (1, MyOtherType(5)), + }, + &FlattenedEnum::A(1, MyOtherType(5)), + ); + check_encode_decode_schema(&FlattenEnum::B(0, (7,))); + check_encode_identically(&FlattenEnum::B(0, (7,)), &FlattenedEnum::B(7)); + check_encode_decode_schema(&FlattenEnum::C(MyInnerStruct { + hello: "howdy".to_string(), + world: MyInnerInnerType(13), + })); + check_encode_identically( + &FlattenEnum::C(MyInnerStruct { + hello: "howdy".to_string(), + world: MyInnerInnerType(13), + }), + &FlattenedEnum::C { + hello: "howdy".to_string(), + world: MyInnerInnerType(13), + }, + ); + check_encode_decode_schema(&FlattenEnum::D); + check_encode_identically(&FlattenEnum::D, &FlattenedEnum::D); + check_encode_decode_schema(&FlattenEnum::E(MyOtherTypeTwo(7))); + check_encode_identically( + &FlattenEnum::E(MyOtherTypeTwo(7)), + &FlattenedEnum::E(MyOtherTypeTwo(7)), + ); check_schema_equality::(); + check_schema_equality::(); check_encode_identically( &Mixed::C { diff --git a/sbor-tests/tests/sbor.rs b/sbor-tests/tests/sbor.rs index 29f10b98596..4cabeca2a2d 100644 --- a/sbor-tests/tests/sbor.rs +++ b/sbor-tests/tests/sbor.rs @@ -2,6 +2,9 @@ use sbor::*; +#[derive(Sbor)] +pub struct UnitStruct; + #[derive(Sbor)] pub struct TestStructNamed { pub state: u32, diff --git a/sbor/src/schema/type_aggregator.rs b/sbor/src/schema/type_aggregator.rs index 673c27fe719..5221e38e966 100644 --- a/sbor/src/schema/type_aggregator.rs +++ b/sbor/src/schema/type_aggregator.rs @@ -162,7 +162,7 @@ impl> TypeAggregator { /// Also tracks it as a named root type, which can be used e.g. in schema comparisons. /// /// This is only intended for use when adding root types to schemas, - /// /and should not be called from inside Describe macros. + /// and should not be called from inside `Describe` implementations. pub fn add_root_type + ?Sized>( &mut self, name: impl Into, @@ -183,7 +183,7 @@ impl> TypeAggregator { /// /// If the type is well known or already in the aggregator, this returns early with the existing index. /// - /// Typically you should use [`add_schema_descendents`], unless you're replacing/mutating + /// Typically you should use [`add_child_type_and_descendents`], unless you're replacing/mutating /// the child types somehow. In which case, you'll likely wish to call [`add_child_type`] and /// [`add_schema_descendents`] separately. ///