diff --git a/Cargo.lock b/Cargo.lock index f3529a8f..7159c2c4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -131,6 +131,25 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "async-or" +version = "0.0.1" +dependencies = [ + "async-or", + "async-or-impl", + "tokio", + "trybuild", +] + +[[package]] +name = "async-or-impl" +version = "0.0.1" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.31", +] + [[package]] name = "async-recursion" version = "1.0.5" @@ -174,6 +193,15 @@ dependencies = [ "rustc-demangle", ] +[[package]] +name = "basic-toml" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7bfc506e7a2370ec239e1d072507b2a80c833083699d3c6fa176fbb4de8448c6" +dependencies = [ + "serde", +] + [[package]] name = "bindgen" version = "0.65.1" @@ -525,6 +553,12 @@ dependencies = [ "spin 0.9.8", ] +[[package]] +name = "dissimilar" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "86e3bdc80eee6e16b2b6b0f87fbc98c04bee3455e35174c0de1a125d0688c632" + [[package]] name = "either" version = "1.9.0" @@ -888,6 +922,7 @@ dependencies = [ name = "idekm" version = "0.1.0" dependencies = [ + "async-or", "codec", "conquer-once", "spdmlib", @@ -1075,6 +1110,7 @@ dependencies = [ name = "mctp_transport" version = "0.1.0" dependencies = [ + "async-or", "async-trait", "codec", "executor", @@ -1245,6 +1281,7 @@ dependencies = [ name = "pcidoe_transport" version = "0.1.0" dependencies = [ + "async-or", "async-trait", "codec", "futures", @@ -1597,6 +1634,7 @@ dependencies = [ name = "spdm-emu" version = "0.1.0" dependencies = [ + "async-or", "async-recursion", "async-trait", "bytes", @@ -1619,6 +1657,7 @@ dependencies = [ name = "spdm-requester-emu" version = "0.1.0" dependencies = [ + "async-or", "codec", "executor", "futures", @@ -1638,6 +1677,7 @@ dependencies = [ name = "spdm-responder-emu" version = "0.1.0" dependencies = [ + "async-or", "codec", "executor", "futures", @@ -1658,6 +1698,7 @@ dependencies = [ name = "spdmlib" version = "0.1.0" dependencies = [ + "async-or", "async-trait", "bit_field", "bitflags 1.3.2", @@ -1766,6 +1807,7 @@ dependencies = [ name = "tdisp" version = "0.2.0" dependencies = [ + "async-or", "bitflags 1.3.2", "codec", "conquer-once", @@ -1875,6 +1917,22 @@ dependencies = [ "syn 2.0.31", ] +[[package]] +name = "trybuild" +version = "1.0.85" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "196a58260a906cedb9bf6d8034b6379d0c11f552416960452f267402ceeddff1" +dependencies = [ + "basic-toml", + "dissimilar", + "glob", + "once_cell", + "serde", + "serde_derive", + "serde_json", + "termcolor", +] + [[package]] name = "unicode-ident" version = "1.0.11" diff --git a/Cargo.toml b/Cargo.toml index 597c7708..2651560d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,6 +3,7 @@ default-members = [ "spdmlib", "codec", + "async-or", "executor", "sys_time", "test/spdm-requester-emu", @@ -12,6 +13,7 @@ default-members = [ members = [ "spdmlib", "codec", + "async-or", "executor", "sys_time", "idekm", diff --git a/async-or/Cargo.toml b/async-or/Cargo.toml new file mode 100644 index 00000000..9b6d773c --- /dev/null +++ b/async-or/Cargo.toml @@ -0,0 +1,24 @@ +[package] +name = "async-or" +version = "0.0.1" +edition = "2018" +authors = ["Longlong Yang "] + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[[test]] +name = "tests" +path = "tests/progress.rs" + +[dev-dependencies] +trybuild = { version = "1.0.49", features = ["diff"] } +tokio = { version = "1.30.0", features = ["full"] } +async-or-impl = { path = "impl", features = ["async"] } +async-or = { path = ".", features = ["async"]} + +[dependencies] +async-or-impl = { path = "impl" } + +[features] +default = [] +async = ["async-or-impl/async"] \ No newline at end of file diff --git a/async-or/impl/Cargo.toml b/async-or/impl/Cargo.toml new file mode 100644 index 00000000..5b184622 --- /dev/null +++ b/async-or/impl/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "async-or-impl" +version = "0.0.1" +edition = "2018" +publish = false + +[lib] +proc-macro = true + +[dependencies] +syn = { version = "2.0", features = ["full", "extra-traits"]} +quote = "1.0" +proc-macro2 = { version = "1.0.66", default-features = false } + +[features] +default = [] +async = [] \ No newline at end of file diff --git a/async-or/impl/src/lib.rs b/async-or/impl/src/lib.rs new file mode 100644 index 00000000..26e16384 --- /dev/null +++ b/async-or/impl/src/lib.rs @@ -0,0 +1,115 @@ +// Copyright (c) 2023 Intel Corporation +// +// SPDX-License-Identifier: Apache-2.0 or MIT +use proc_macro::TokenStream; + +#[proc_macro_attribute] +#[cfg(feature = "async")] +pub fn async_or(_args: TokenStream, input: TokenStream) -> TokenStream { + let mut input: syn::ItemFn = syn::parse_macro_input!(input as syn::ItemFn); + + input.sig.asyncness = Some(syn::Token![async](proc_macro2::Span::call_site())); + + use quote::ToTokens; + input.to_token_stream().into() +} + +#[proc_macro_attribute] +#[cfg(not(feature = "async"))] +pub fn async_or(_args: TokenStream, input: TokenStream) -> TokenStream { + input +} + +fn is_ident_present_in_attr(attr: &syn::Attribute, ident: &str) -> bool { + match &attr.meta { + syn::Meta::Path(path) => path.is_ident(ident), + syn::Meta::List(ml) => ml.path.is_ident(ident), + syn::Meta::NameValue(nv) => nv.path.is_ident(ident), + } +} + +#[proc_macro_attribute] +#[cfg(feature = "async")] +pub fn async_trait_or(_args: TokenStream, input: TokenStream) -> TokenStream { + let mut input: syn::ItemTrait = syn::parse_macro_input!(input as syn::ItemTrait); + for item in input.items.iter_mut() { + if let syn::TraitItem::Fn(trait_item_fn) = item { + for (index, attr) in trait_item_fn.attrs.iter().enumerate() { + if is_ident_present_in_attr(attr, "async_or") { + trait_item_fn.sig.asyncness = + Some(syn::Token![async](proc_macro2::Span::call_site())); + trait_item_fn.attrs.remove(index); + break; + } + } + } + } + + TokenStream::from(quote::quote! { + #[async_trait::async_trait] + #input + }) +} + +#[proc_macro_attribute] +#[cfg(not(feature = "async"))] +pub fn async_trait_or(_args: TokenStream, input: TokenStream) -> TokenStream { + let mut input: syn::ItemTrait = syn::parse_macro_input!(input as syn::ItemTrait); + for item in input.items.iter_mut() { + if let syn::TraitItem::Fn(trait_item_fn) = item { + for (index, attr) in trait_item_fn.attrs.iter().enumerate() { + if is_ident_present_in_attr(attr, "async_or") { + trait_item_fn.attrs.remove(index); + break; + } + } + } + } + + TokenStream::from(quote::quote! { + #input + }) +} + +#[proc_macro_attribute] +#[cfg(feature = "async")] +pub fn async_impl_or(_args: TokenStream, input: TokenStream) -> TokenStream { + let mut input: syn::ItemImpl = syn::parse_macro_input!(input as syn::ItemImpl); + for item in input.items.iter_mut() { + if let syn::ImplItem::Fn(impl_item_fn) = item { + for (index, attr) in impl_item_fn.attrs.iter().enumerate() { + if is_ident_present_in_attr(attr, "async_or") { + impl_item_fn.sig.asyncness = + Some(syn::Token![async](proc_macro2::Span::call_site())); + impl_item_fn.attrs.remove(index); + break; + } + } + } + } + + TokenStream::from(quote::quote! { + #[async_trait::async_trait] + #input + }) +} + +#[proc_macro_attribute] +#[cfg(not(feature = "async"))] +pub fn async_impl_or(_args: TokenStream, input: TokenStream) -> TokenStream { + let mut input: syn::ItemImpl = syn::parse_macro_input!(input as syn::ItemImpl); + for item in input.items.iter_mut() { + if let syn::ImplItem::Fn(impl_item_fn) = item { + for (index, attr) in impl_item_fn.attrs.iter().enumerate() { + if is_ident_present_in_attr(attr, "async_or") { + impl_item_fn.attrs.remove(index); + break; + } + } + } + } + + TokenStream::from(quote::quote! { + #input + }) +} diff --git a/async-or/src/lib.rs b/async-or/src/lib.rs new file mode 100644 index 00000000..33d20393 --- /dev/null +++ b/async-or/src/lib.rs @@ -0,0 +1,24 @@ +// Copyright (c) 2023 Intel Corporation +// +// SPDX-License-Identifier: Apache-2.0 or MIT +#![no_std] + +pub use async_or_impl::async_impl_or; +pub use async_or_impl::async_or; +pub use async_or_impl::async_trait_or; + +#[cfg(feature = "async")] +#[macro_export] +macro_rules! await_or { + ($e:expr) => { + $e.await + }; +} + +#[cfg(not(feature = "async"))] +#[macro_export] +macro_rules! await_or { + ($e:expr) => { + $e + }; +} diff --git a/async-or/tests/01-fun-to-async.rs b/async-or/tests/01-fun-to-async.rs new file mode 100644 index 00000000..1f577f68 --- /dev/null +++ b/async-or/tests/01-fun-to-async.rs @@ -0,0 +1,15 @@ +// Copyright (c) 2023 Intel Corporation +// +// SPDX-License-Identifier: Apache-2.0 or MIT + +use async_or_impl::*; + +#[async_or] +pub fn test() { + panic!("This test function should be async-able, So this panic should be not called w/out async runtime!"); +} + +#[allow(unused_must_use)] +fn main() { + test(); +} diff --git a/async-or/tests/02-fun-call-to-await.rs b/async-or/tests/02-fun-call-to-await.rs new file mode 100644 index 00000000..bd49afa4 --- /dev/null +++ b/async-or/tests/02-fun-call-to-await.rs @@ -0,0 +1,30 @@ +// Copyright (c) 2023 Intel Corporation +// +// SPDX-License-Identifier: Apache-2.0 or MIT + +use async_or::await_or; +use async_or_impl::*; +use std::panic::{catch_unwind, AssertUnwindSafe}; + +#[async_or] +fn test_to_panic() { + panic!("This test function is set to panic!"); +} + +#[async_or] +pub fn test() { + await_or!(test_to_panic()); +} + +fn main() { + let rt = tokio::runtime::Runtime::new().unwrap(); + + let msg = catch_unwind(AssertUnwindSafe(|| { + rt.block_on(test()); + })); + + assert_eq!( + "This test function is set to panic!", + *msg.unwrap_err().downcast_ref::<&str>().unwrap() + ); +} diff --git a/async-or/tests/progress.rs b/async-or/tests/progress.rs new file mode 100644 index 00000000..ee546c3d --- /dev/null +++ b/async-or/tests/progress.rs @@ -0,0 +1,10 @@ +// Copyright (c) 2023 Intel Corporation +// +// SPDX-License-Identifier: Apache-2.0 or MIT + +#[test] +fn tests() { + let t = trybuild::TestCases::new(); + t.pass("tests/01-fun-to-async.rs"); + t.pass("tests/02-fun-call-to-await.rs"); +} diff --git a/fuzz-target/fuzzlib/Cargo.toml b/fuzz-target/fuzzlib/Cargo.toml index 17ac3237..caedb7c2 100644 --- a/fuzz-target/fuzzlib/Cargo.toml +++ b/fuzz-target/fuzzlib/Cargo.toml @@ -8,7 +8,7 @@ edition = "2018" [dependencies] afl = { version = "=0.12.12", optional = true } -spdmlib = { path = "../../spdmlib", default-features = false, features=["spdm-ring"] } +spdmlib = { path = "../../spdmlib", default-features = false, features=["spdm-ring", "async"] } simple_logger = "4.2.0" log = "0.4.13" ring = { version = "0.17.6" } diff --git a/idekm/Cargo.toml b/idekm/Cargo.toml index 463e2cd4..0d8c7313 100644 --- a/idekm/Cargo.toml +++ b/idekm/Cargo.toml @@ -18,6 +18,8 @@ codec = { path = "../codec" } zeroize = { version = "1.5.0", features = ["zeroize_derive"]} spdmlib = { path = "../spdmlib", default-features = false, features = ["spdm-ring"]} conquer-once = { version = "0.3.2", default-features = false } +async-or = { path = "../async-or" } [features] +async = ["spdmlib/async", "async-or/async"] \ No newline at end of file diff --git a/idekm/src/pci_ide_km_requester/pci_ide_km_req_key_prog.rs b/idekm/src/pci_ide_km_requester/pci_ide_km_req_key_prog.rs index 1958cfc9..1af7223a 100644 --- a/idekm/src/pci_ide_km_requester/pci_ide_km_req_key_prog.rs +++ b/idekm/src/pci_ide_km_requester/pci_ide_km_req_key_prog.rs @@ -2,6 +2,8 @@ // // SPDX-License-Identifier: Apache-2.0 or MIT +use async_or::async_or; +use async_or::await_or; use codec::Codec; use codec::Writer; use spdmlib::error::SpdmResult; @@ -21,7 +23,8 @@ use super::IdekmReqContext; impl IdekmReqContext { #[allow(clippy::too_many_arguments)] - pub async fn pci_ide_km_key_prog( + #[async_or] + pub fn pci_ide_km_key_prog( &mut self, // IN spdm_requester: &mut RequesterContext, @@ -55,14 +58,13 @@ impl IdekmReqContext { .map_err(|_| SPDM_STATUS_BUFFER_FULL)? as u16; - let vendor_defined_rsp_payload_struct = spdm_requester + let vendor_defined_rsp_payload_struct = await_or!(spdm_requester .send_spdm_vendor_defined_request( Some(session_id), STANDARD_ID, vendor_id(), vendor_defined_req_payload_struct, - ) - .await?; + ))?; let kp_ack_data_object = KpAckDataObject::read_bytes( &vendor_defined_rsp_payload_struct.vendor_defined_rsp_payload diff --git a/idekm/src/pci_ide_km_requester/pci_ide_km_req_key_set_go.rs b/idekm/src/pci_ide_km_requester/pci_ide_km_req_key_set_go.rs index 38d2e5ad..cc361faf 100644 --- a/idekm/src/pci_ide_km_requester/pci_ide_km_req_key_set_go.rs +++ b/idekm/src/pci_ide_km_requester/pci_ide_km_req_key_set_go.rs @@ -2,6 +2,8 @@ // // SPDX-License-Identifier: Apache-2.0 or MIT +use async_or::async_or; +use async_or::await_or; use codec::Codec; use codec::Writer; use spdmlib::error::SPDM_STATUS_BUFFER_FULL; @@ -21,7 +23,8 @@ use super::IdekmReqContext; impl IdekmReqContext { #[allow(clippy::too_many_arguments)] - pub async fn pci_ide_km_key_set_go( + #[async_or] + pub fn pci_ide_km_key_set_go( &mut self, // IN spdm_requester: &mut RequesterContext, @@ -51,14 +54,13 @@ impl IdekmReqContext { .map_err(|_| SPDM_STATUS_BUFFER_FULL)? as u16; - let vendor_defined_rsp_payload_struct = spdm_requester + let vendor_defined_rsp_payload_struct = await_or!(spdm_requester .send_spdm_vendor_defined_request( Some(session_id), STANDARD_ID, vendor_id(), vendor_defined_req_payload_struct, - ) - .await?; + ))?; let kgo_stop_ack_data_object = KGoStopAckDataObject::read_bytes( &vendor_defined_rsp_payload_struct.vendor_defined_rsp_payload diff --git a/idekm/src/pci_ide_km_requester/pci_ide_km_req_key_set_stop.rs b/idekm/src/pci_ide_km_requester/pci_ide_km_req_key_set_stop.rs index e23c85e7..ee9ca82c 100644 --- a/idekm/src/pci_ide_km_requester/pci_ide_km_req_key_set_stop.rs +++ b/idekm/src/pci_ide_km_requester/pci_ide_km_req_key_set_stop.rs @@ -2,6 +2,8 @@ // // SPDX-License-Identifier: Apache-2.0 or MIT +use async_or::async_or; +use async_or::await_or; use codec::Codec; use codec::Writer; use spdmlib::error::SPDM_STATUS_BUFFER_FULL; @@ -21,7 +23,8 @@ use super::IdekmReqContext; impl IdekmReqContext { #[allow(clippy::too_many_arguments)] - pub async fn pci_ide_km_key_set_stop( + #[async_or] + pub fn pci_ide_km_key_set_stop( &mut self, // IN spdm_requester: &mut RequesterContext, @@ -51,14 +54,13 @@ impl IdekmReqContext { .map_err(|_| SPDM_STATUS_BUFFER_FULL)? as u16; - let vendor_defined_rsp_payload_struct = spdm_requester + let vendor_defined_rsp_payload_struct = await_or!(spdm_requester .send_spdm_vendor_defined_request( Some(session_id), STANDARD_ID, vendor_id(), vendor_defined_req_payload_struct, - ) - .await?; + ))?; let kgo_stop_ack_data_object = KGoStopAckDataObject::read_bytes( &vendor_defined_rsp_payload_struct.vendor_defined_rsp_payload diff --git a/idekm/src/pci_ide_km_requester/pci_ide_km_req_query.rs b/idekm/src/pci_ide_km_requester/pci_ide_km_req_query.rs index a491df80..f05303eb 100644 --- a/idekm/src/pci_ide_km_requester/pci_ide_km_req_query.rs +++ b/idekm/src/pci_ide_km_requester/pci_ide_km_req_query.rs @@ -2,6 +2,7 @@ // // SPDX-License-Identifier: Apache-2.0 or MIT +use async_or::{async_or, await_or}; use codec::{Codec, Writer}; use spdmlib::{ error::{SpdmResult, SPDM_STATUS_BUFFER_FULL, SPDM_STATUS_INVALID_MSG_FIELD}, @@ -18,7 +19,8 @@ use super::IdekmReqContext; impl IdekmReqContext { #[allow(clippy::too_many_arguments)] - pub async fn pci_ide_km_query( + #[async_or] + pub fn pci_ide_km_query( &mut self, // IN spdm_requester: &mut RequesterContext, @@ -45,14 +47,13 @@ impl IdekmReqContext { .map_err(|_| SPDM_STATUS_BUFFER_FULL)? as u16; - let vendor_defined_rsp_payload_struct = spdm_requester + let vendor_defined_rsp_payload_struct = await_or!(spdm_requester .send_spdm_vendor_defined_request( Some(session_id), STANDARD_ID, vendor_id(), vendor_defined_req_payload_struct, - ) - .await?; + ))?; let query_resp_data_object = QueryRespDataObject::read_bytes( &vendor_defined_rsp_payload_struct.vendor_defined_rsp_payload diff --git a/mctp_transport/Cargo.toml b/mctp_transport/Cargo.toml index dd2819b9..ccc8184d 100644 --- a/mctp_transport/Cargo.toml +++ b/mctp_transport/Cargo.toml @@ -16,3 +16,7 @@ spdmlib = { path = "../spdmlib", default-features = false} futures = { version = "0.3", default-features = false } async-trait = "0.1.71" executor = { path = "../executor" } +async-or = { path = "../async-or" } + +[features] +async = ["spdmlib/async", "async-or/async"] \ No newline at end of file diff --git a/mctp_transport/src/header.rs b/mctp_transport/src/header.rs index 9a4e3dc4..3b109d5c 100644 --- a/mctp_transport/src/header.rs +++ b/mctp_transport/src/header.rs @@ -2,7 +2,7 @@ // // SPDX-License-Identifier: Apache-2.0 or MIT -use async_trait::async_trait; +use async_or::async_impl_or; use codec::enum_builder; use codec::{Codec, Reader, Writer}; use spdmlib::common::SpdmTransportEncap; @@ -57,9 +57,10 @@ impl Codec for MctpMessageHeader { #[derive(Debug, Copy, Clone, Default)] pub struct MctpTransportEncap {} -#[async_trait] +#[async_impl_or] impl SpdmTransportEncap for MctpTransportEncap { - async fn encap( + #[async_or] + fn encap( &mut self, spdm_buffer: Arc<&[u8]>, transport_buffer: Arc>, @@ -87,7 +88,8 @@ impl SpdmTransportEncap for MctpTransportEncap { Ok(header_size + payload_len) } - async fn decap( + #[async_or] + fn decap( &mut self, transport_buffer: Arc<&[u8]>, spdm_buffer: Arc>, @@ -119,7 +121,8 @@ impl SpdmTransportEncap for MctpTransportEncap { Ok((payload_size, secured_message)) } - async fn encap_app( + #[async_or] + fn encap_app( &mut self, spdm_buffer: Arc<&[u8]>, app_buffer: Arc>, @@ -149,7 +152,8 @@ impl SpdmTransportEncap for MctpTransportEncap { Ok(header_size + payload_len) } - async fn decap_app( + #[async_or] + fn decap_app( &mut self, app_buffer: Arc<&[u8]>, spdm_buffer: Arc>, diff --git a/pcidoe_transport/Cargo.toml b/pcidoe_transport/Cargo.toml index 7f6e4a6e..af3e4cd8 100644 --- a/pcidoe_transport/Cargo.toml +++ b/pcidoe_transport/Cargo.toml @@ -15,3 +15,7 @@ spdmlib = { path = "../spdmlib", default-features = false} futures = { version = "0.3", default-features = false } async-trait = "0.1.71" spin = { version = "0.9.8" } +async-or = { path = "../async-or" } + +[features] +async = ["spdmlib/async", "async-or/async"] \ No newline at end of file diff --git a/pcidoe_transport/src/header.rs b/pcidoe_transport/src/header.rs index 072792db..2cf13895 100644 --- a/pcidoe_transport/src/header.rs +++ b/pcidoe_transport/src/header.rs @@ -2,7 +2,7 @@ // // SPDX-License-Identifier: Apache-2.0 or MIT -use async_trait::async_trait; +use async_or::async_impl_or; extern crate alloc; use alloc::boxed::Box; use alloc::sync::Arc; @@ -89,9 +89,10 @@ impl Codec for PciDoeMessageHeader { #[derive(Debug, Copy, Clone, Default)] pub struct PciDoeTransportEncap {} -#[async_trait] +#[async_impl_or] impl SpdmTransportEncap for PciDoeTransportEncap { - async fn encap( + #[async_or] + fn encap( &mut self, spdm_buffer: Arc<&[u8]>, transport_buffer: Arc>, @@ -122,7 +123,8 @@ impl SpdmTransportEncap for PciDoeTransportEncap { Ok(header_size + aligned_payload_len) } - async fn decap( + #[async_or] + fn decap( &mut self, transport_buffer: Arc<&[u8]>, spdm_buffer: Arc>, @@ -154,7 +156,8 @@ impl SpdmTransportEncap for PciDoeTransportEncap { Ok((payload_size, secured_message)) } - async fn encap_app( + #[async_or] + fn encap_app( &mut self, spdm_buffer: Arc<&[u8]>, app_buffer: Arc>, @@ -166,7 +169,8 @@ impl SpdmTransportEncap for PciDoeTransportEncap { Ok(spdm_buffer.len()) } - async fn decap_app( + #[async_or] + fn decap_app( &mut self, app_buffer: Arc<&[u8]>, spdm_buffer: Arc>, diff --git a/readme.md b/readme.md index 9d7f9eb0..240275cf 100644 --- a/readme.md +++ b/readme.md @@ -122,12 +122,18 @@ cargo fmt cargo build ``` -### Build `no_std` spdm +### Build sync `no_std` spdm ``` pushd spdmlib cargo build -Z build-std=core,alloc,compiler_builtins --target x86_64-unknown-none --release --no-default-features --features="spdm-ring" ``` +### Build async `no_std` spdm +``` +pushd spdmlib +cargo build -Z build-std=core,alloc,compiler_builtins --target x86_64-unknown-none --release --no-default-features --features="spdm-ring,async" +``` + ## Run Rust SPDM emulator ### Run emulator with default feature @@ -147,19 +153,25 @@ cargo run -p spdm-requester-emu --no-default-features --features "spdm-ring,hash The following list shows the supported combinations for both spdm-requester-emu and spdm-responder-emu -| Features | CryptoLibrary | Hashed transcript data support | async runtime | notes | -|-------------------------------------------------|---------------|--------------------------------|---------------|---------------------------------------------------------------------------------------------------| -| spdm-ring,async-executor | ring | No | executor | use ring as crypto library with hashed-transcript-data disabled, use executor as async runtime | -| spdm-ring,hashed-transcript-data,async-executor | ring | Yes | executor | use ring as crypto library with hashed-transcript-data enabled, use executor as async runtime | -| spdm-ring,hashed-transcript-data,async-tokio | ring | Yes | tokio | use ring as crypto library with hashed-transcript-data enabled, use tokio as async runtime | -| spdm-mbedtls,async-executor | mbedtls | No | executor | use mbedtls as crypto library with hashed-transcript-data disabled, use executor as async runtime | -| spdm-mbedtls,hashed-transcript-data,async-executor | mbedtls | Yes | executor | use mbedtls as crypto library with hashed-transcript-data enabled, use executor as async runtime | -| spdm-mbedtls,hashed-transcript-data,async-tokio | mbedtls | Yes | tokio | use mbedtls as crypto library with hashed-transcript-data enabled, use tokio as async runtime | +| Features | CryptoLibrary | Hashed transcript data support | sync/async | notes | +|----------------------------------------------------|---------------|--------------------------------|------------------------|-----------------------------------------------------------------------------------------------------------------| +| spdm-ring | ring | No | sync | use ring as crypto library with hashed-transcript-data disabled, sync version. | +| spdm-ring,hashed-transcript-data | ring | Yes | sync | use ring as crypto library with hashed-transcript-data enabled, sync version. | +| spdm-ring,hashed-transcript-data,async-tokio | ring | Yes | tokio async runtime | use ring as crypto library with hashed-transcript-data enabled, async version, use tokio as async runtime | +| spdm-mbedtls | mbedtls | No | sync | use mbedtls as crypto library with hashed-transcript-data disabled, sync version. | +| spdm-mbedtls,hashed-transcript-data | mbedtls | Yes | sync | use mbedtls as crypto library with hashed-transcript-data enabled, sync version. | +| spdm-mbedtls,hashed-transcript-data,async-executor | mbedtls | Yes | executor async runtime | use mbedtls as crypto library with hashed-transcript-data enabled, async version, use executor as async runtime | For example, run the emulator with spdm-ring enabled and without hashed-transcript-data enabled, and use executor as async runtime. Open one command windows and run: ``` -cargo run -p spdm-responder-emu --no-default-features --features "spdm-ring,async-executor " +cargo run -p spdm-responder-emu --no-default-features --features "spdm-ring,async-executor" +``` + +run the emulator with spdm-ring enabled and without hashed-transcript-data enabled, and use sync version. +Open one command windows and run: +``` +cargo run -p spdm-responder-emu --no-default-features --features "spdm-ring" ``` run the emulator with spdm-mbedtls enabled and with hashed-transcript-data enabled, and use tokio as async runtime. @@ -194,14 +206,14 @@ spdm_responder_emu.exe --trans PCI_DOE 2. run rust-spdm-emu as requester: ``` -cargo run -p spdm-requester-emu --no-default-features --features "spdm-ring,hashed-transcript-data,async-executor " +cargo run -p spdm-requester-emu --no-default-features --features "spdm-ring,hashed-transcript-data,async-executor" ``` Test rust-spdm as responder: 1. run rust-spdm-emu as Test rust-spdm as responder: ``` -cargo run -p spdm-responder-emu --no-default-features --features "spdm-ring,hashed-transcript-data,async-executor " +cargo run -p spdm-responder-emu --no-default-features --features "spdm-ring,hashed-transcript-data,async-executor" ``` 2. run libspdm in spdm-emu as requester: diff --git a/sh_script/build.sh b/sh_script/build.sh index 4ebd4fe1..8d83937d 100755 --- a/sh_script/build.sh +++ b/sh_script/build.sh @@ -54,25 +54,43 @@ build() { echo "Building Rust-SPDM with spdm-ring feature..." echo_command cargo build --release --no-default-features --features=spdm-ring + + echo "Building Rust-SPDM with spdm-ring,async feature..." + echo_command cargo build --release --no-default-features --features=spdm-ring,async echo "Building Rust-SPDM with spdm-ring,hashed-transcript-data feature..." echo_command cargo build --release --no-default-features --features=spdm-ring,hashed-transcript-data + + echo "Building Rust-SPDM with spdm-ring,hashed-transcript-data,async feature..." + echo_command cargo build --release --no-default-features --features=spdm-ring,hashed-transcript-data,async echo "Building Rust-SPDM with spdm-ring,hashed-transcript-data,mut-auth feature..." echo_command cargo build --release --no-default-features --features=spdm-ring,hashed-transcript-data,mut-auth + echo "Building Rust-SPDM with spdm-ring,hashed-transcript-data,mut-auth,async feature..." + echo_command cargo build --release --no-default-features --features=spdm-ring,hashed-transcript-data,mut-auth,async + if [ -z "$RUSTFLAGS" ]; then echo "Building Rust-SPDM in no std with no-default-features..." echo_command cargo build -Z build-std=core,alloc,compiler_builtins --target x86_64-unknown-none --release --no-default-features echo "Building Rust-SPDM in no std with spdm-ring feature..." echo_command cargo build -Z build-std=core,alloc,compiler_builtins --target x86_64-unknown-none --release --no-default-features --features="spdm-ring" - + + echo "Building Rust-SPDM in no std with spdm-ring,async feature..." + echo_command cargo build -Z build-std=core,alloc,compiler_builtins --target x86_64-unknown-none --release --no-default-features --features="spdm-ring,async" + echo "Building Rust-SPDM in no std with spdm-ring,hashed-transcript-data feature..." echo_command cargo build -Z build-std=core,alloc,compiler_builtins --target x86_64-unknown-none --release --no-default-features --features="spdm-ring,hashed-transcript-data" + + echo "Building Rust-SPDM in no std with spdm-ring,hashed-transcript-data,async feature..." + echo_command cargo build -Z build-std=core,alloc,compiler_builtins --target x86_64-unknown-none --release --no-default-features --features="spdm-ring,hashed-transcript-data,async" echo "Building Rust-SPDM in no std with spdm-ring,hashed-transcript-data,mut-auth feature..." echo_command cargo build -Z build-std=core,alloc,compiler_builtins --target x86_64-unknown-none --release --no-default-features --features="spdm-ring,hashed-transcript-data,mut-auth" + + echo "Building Rust-SPDM in no std with spdm-ring,hashed-transcript-data,mut-auth,async feature..." + echo_command cargo build -Z build-std=core,alloc,compiler_builtins --target x86_64-unknown-none --release --no-default-features --features="spdm-ring,hashed-transcript-data,mut-auth,async" fi popd @@ -84,11 +102,13 @@ build() { echo_command cargo build -p spdm-responder-emu } -RUN_REQUESTER_FEATURES=${RUN_REQUESTER_FEATURES:-spdm-ring,hashed-transcript-data,async-executor} -RUN_RESPONDER_FEATURES=${RUN_RESPONDER_FEATURES:-spdm-ring,hashed-transcript-data,async-executor} -RUN_REQUESTER_MUTAUTH_FEATURES="${RUN_REQUESTER_FEATURES},mut-auth" -RUN_RESPONDER_MUTAUTH_FEATURES="${RUN_RESPONDER_FEATURES},mut-auth" -RUN_RESPONDER_MANDATORY_MUTAUTH_FEATURES="${RUN_RESPONDER_FEATURES},mandatory-mut-auth" +RUN_REQUESTER_FEATURES=${RUN_REQUESTER_FEATURES:-spdm-ring,hashed-transcript-data} +RUN_RESPONDER_FEATURES=${RUN_RESPONDER_FEATURES:-spdm-ring,hashed-transcript-data} +RUN_REQUESTER_FEATURES_WITH_ASYNC=${RUN_REQUESTER_FEATURES:-spdm-ring,hashed-transcript-data,async-executor} +RUN_RESPONDER_FEATURES_WITH_ASYNC=${RUN_RESPONDER_FEATURES:-spdm-ring,hashed-transcript-data,async-executor} +RUN_REQUESTER_MUTAUTH_FEATURES="${RUN_REQUESTER_FEATURES_WITH_ASYNC},mut-auth" +RUN_RESPONDER_MUTAUTH_FEATURES="${RUN_RESPONDER_FEATURES_WITH_ASYNC},mut-auth" +RUN_RESPONDER_MANDATORY_MUTAUTH_FEATURES="${RUN_RESPONDER_FEATURES_WITH_ASYNC},mandatory-mut-auth" run_with_spdm_emu() { echo "Running with spdm-emu..." @@ -138,7 +158,7 @@ run_with_spdm_emu_mandatory_mut_auth() { run_basic_test() { echo "Running basic tests..." echo_command cargo test -- --test-threads=1 - echo_command cargo test --no-default-features --features "spdmlib/std,spdmlib/spdm-ring" -- --test-threads=1 + echo_command cargo test --no-default-features --features "spdmlib/std,spdmlib/spdm-ring,async" -- --test-threads=1 echo "Running basic tests finished..." echo "Running spdmlib-test..." @@ -156,6 +176,14 @@ run_rust_spdm_emu() { cleanup } +run_async_rust_spdm_emu() { + echo "Running requester and responder..." + echo_command cargo run -p spdm-responder-emu --no-default-features --features="$RUN_RESPONDER_FEATURES_WITH_ASYNC" & + sleep 20 + echo_command cargo run -p spdm-requester-emu --no-default-features --features="$RUN_REQUESTER_FEATURES_WITH_ASYNC" + cleanup +} + run_rust_spdm_emu_mut_auth() { echo "Running requester and responder mutual authentication..." echo $RUN_REQUESTER_MUTAUTH_FEATURES @@ -177,6 +205,7 @@ run_rust_spdm_emu_mandatory_mut_auth() { run() { run_basic_test run_rust_spdm_emu + run_async_rust_spdm_emu run_rust_spdm_emu_mut_auth run_rust_spdm_emu_mandatory_mut_auth } diff --git a/spdmlib/Cargo.toml b/spdmlib/Cargo.toml index bac27f3d..d33500cc 100644 --- a/spdmlib/Cargo.toml +++ b/spdmlib/Cargo.toml @@ -11,6 +11,7 @@ edition = "2018" [dependencies] codec = {path= "../codec"} +async-or = { path = "../async-or" } bitflags = "1.2.1" log = "0.4.13" bytes = { version="1", default-features=false } @@ -48,3 +49,4 @@ downcast = [] hashed-transcript-data = [] mut-auth = [] mandatory-mut-auth = ["mut-auth"] +async = ["async-or/async"] diff --git a/spdmlib/src/common/mod.rs b/spdmlib/src/common/mod.rs index 1eb1f71f..32154e72 100644 --- a/spdmlib/src/common/mod.rs +++ b/spdmlib/src/common/mod.rs @@ -9,13 +9,14 @@ pub mod spdm_codec; use crate::message::SpdmRequestResponseCode; use crate::{crypto, protocol::*}; +use async_or::{async_or, async_trait_or, await_or}; use spin::Mutex; extern crate alloc; +#[allow(unused_imports)] use alloc::boxed::Box; use alloc::sync::Arc; use core::ops::DerefMut; -use async_trait::async_trait; pub use opaque::*; pub use spdm_codec::SpdmCodec; @@ -78,17 +79,16 @@ pub const INITIAL_SESSION_ID: u16 = 0xFFFD; pub const INVALID_HALF_SESSION_ID: u16 = 0x0; pub const INVALID_SESSION_ID: u32 = 0x0; -#[async_trait] +#[async_trait_or] pub trait SpdmDeviceIo { - async fn send(&mut self, buffer: Arc<&[u8]>) -> SpdmResult; + #[async_or] + fn send(&mut self, buffer: Arc<&[u8]>) -> SpdmResult; - async fn receive( - &mut self, - buffer: Arc>, - timeout: usize, - ) -> Result; + #[async_or] + fn receive(&mut self, buffer: Arc>, timeout: usize) -> Result; - async fn flush_all(&mut self) -> SpdmResult; + #[async_or] + fn flush_all(&mut self) -> SpdmResult; #[cfg(feature = "downcast")] fn as_any(&mut self) -> &mut dyn Any; @@ -96,29 +96,33 @@ pub trait SpdmDeviceIo { use core::fmt::Debug; -#[async_trait] +#[async_trait_or] pub trait SpdmTransportEncap { - async fn encap( + #[async_or] + fn encap( &mut self, spdm_buffer: Arc<&[u8]>, transport_buffer: Arc>, secured_message: bool, ) -> SpdmResult; - async fn decap( + #[async_or] + fn decap( &mut self, transport_buffer: Arc<&[u8]>, spdm_buffer: Arc>, ) -> SpdmResult<(usize, bool)>; - async fn encap_app( + #[async_or] + fn encap_app( &mut self, spdm_buffer: Arc<&[u8]>, app_buffer: Arc>, is_app_message: bool, ) -> SpdmResult; - async fn decap_app( + #[async_or] + fn decap_app( &mut self, app_buffer: Arc<&[u8]>, spdm_buffer: Arc>, @@ -967,23 +971,19 @@ impl SpdmContext { } } - pub async fn encap( - &mut self, - send_buffer: &[u8], - transport_buffer: &mut [u8], - ) -> SpdmResult { + #[async_or] + pub fn encap(&mut self, send_buffer: &[u8], transport_buffer: &mut [u8]) -> SpdmResult { let mut transport_encap = self.transport_encap.lock(); let transport_encap: &mut (dyn SpdmTransportEncap + Send + Sync) = transport_encap.deref_mut(); let send_buffer = Arc::new(send_buffer); let transport_buffer = Mutex::new(transport_buffer); let transport_buffer = Arc::new(transport_buffer); - transport_encap - .encap(send_buffer, transport_buffer, false) - .await + await_or!(transport_encap.encap(send_buffer, transport_buffer, false)) } - pub async fn encode_secured_message( + #[async_or] + pub fn encode_secured_message( &mut self, session_id: u32, send_buffer: &[u8], @@ -999,9 +999,7 @@ impl SpdmContext { let send_buffer = Arc::new(send_buffer); let app_buffer = Mutex::new(&mut app_buffer[..]); let app_buffer = Arc::new(app_buffer); - transport_encap - .encap_app(send_buffer, app_buffer, is_app_message) - .await? + await_or!(transport_encap.encap_app(send_buffer, app_buffer, is_app_message))? }; let spdm_session = self @@ -1018,16 +1016,15 @@ impl SpdmContext { let mut transport_encap = self.transport_encap.lock(); let transport_encap: &mut (dyn SpdmTransportEncap + Send + Sync) = transport_encap.deref_mut(); - transport_encap - .encap( - Arc::new(&encoded_send_buffer[..encode_size]), - Arc::new(Mutex::new(transport_buffer)), - true, - ) - .await + await_or!(transport_encap.encap( + Arc::new(&encoded_send_buffer[..encode_size]), + Arc::new(Mutex::new(transport_buffer)), + true, + )) } - pub async fn decap( + #[async_or] + pub fn decap( &mut self, transport_buffer: &[u8], receive_buffer: &mut [u8], @@ -1036,12 +1033,10 @@ impl SpdmContext { let transport_encap: &mut (dyn SpdmTransportEncap + Send + Sync) = transport_encap.deref_mut(); - let (used, secured_message) = transport_encap - .decap( - Arc::new(transport_buffer), - Arc::new(Mutex::new(receive_buffer)), - ) - .await?; + let (used, secured_message) = await_or!(transport_encap.decap( + Arc::new(transport_buffer), + Arc::new(Mutex::new(receive_buffer)), + ))?; if secured_message { return Err(SPDM_STATUS_DECAP_FAIL); //need check @@ -1050,7 +1045,8 @@ impl SpdmContext { Ok(used) } - pub async fn decode_secured_message( + #[async_or] + pub fn decode_secured_message( &mut self, session_id: u32, transport_buffer: &[u8], @@ -1063,12 +1059,10 @@ impl SpdmContext { let transport_encap: &mut (dyn SpdmTransportEncap + Send + Sync) = transport_encap.deref_mut(); - transport_encap - .decap( - Arc::new(transport_buffer), - Arc::new(Mutex::new(&mut encoded_receive_buffer)), - ) - .await? + await_or!(transport_encap.decap( + Arc::new(transport_buffer), + Arc::new(Mutex::new(&mut encoded_receive_buffer)), + ))? }; if !secured_message { @@ -1090,12 +1084,10 @@ impl SpdmContext { let transport_encap: &mut (dyn SpdmTransportEncap + Send + Sync) = transport_encap.deref_mut(); - let used = transport_encap - .decap_app( - Arc::new(&app_buffer[0..decode_size]), - Arc::new(Mutex::new(receive_buffer)), - ) - .await?; + let used = await_or!(transport_encap.decap_app( + Arc::new(&app_buffer[0..decode_size]), + Arc::new(Mutex::new(receive_buffer)), + ))?; Ok(used.0) } diff --git a/spdmlib/src/requester/challenge_req.rs b/spdmlib/src/requester/challenge_req.rs index 88f97415..40c255d5 100644 --- a/spdmlib/src/requester/challenge_req.rs +++ b/spdmlib/src/requester/challenge_req.rs @@ -2,6 +2,8 @@ // // SPDX-License-Identifier: Apache-2.0 or MIT +use async_or::{async_or, await_or}; + use crate::crypto; #[cfg(feature = "hashed-transcript-data")] use crate::error::SPDM_STATUS_INVALID_STATE_LOCAL; @@ -14,7 +16,8 @@ use crate::protocol::*; use crate::requester::*; impl RequesterContext { - pub async fn send_receive_spdm_challenge( + #[async_or] + pub fn send_receive_spdm_challenge( &mut self, slot_id: u8, measurement_summary_hash_type: SpdmMeasurementSummaryHashType, @@ -31,14 +34,11 @@ impl RequesterContext { let mut send_buffer = [0u8; config::MAX_SPDM_MSG_SIZE]; let send_used = self.encode_spdm_challenge(slot_id, measurement_summary_hash_type, &mut send_buffer)?; - self.send_message(None, &send_buffer[..send_used], false) - .await?; + await_or!(self.send_message(None, &send_buffer[..send_used], false))?; // Receive let mut receive_buffer = [0u8; config::MAX_SPDM_MSG_SIZE]; - let used = self - .receive_message(None, &mut receive_buffer, true) - .await?; + let used = await_or!(self.receive_message(None, &mut receive_buffer, true))?; self.handle_spdm_challenge_response( 0, // NULL slot_id, diff --git a/spdmlib/src/requester/context.rs b/spdmlib/src/requester/context.rs index 0dc53b98..8b48400c 100644 --- a/spdmlib/src/requester/context.rs +++ b/spdmlib/src/requester/context.rs @@ -8,6 +8,7 @@ use crate::config; use crate::error::{SpdmResult, SPDM_STATUS_RECEIVE_FAIL, SPDM_STATUS_SEND_FAIL}; use crate::protocol::*; +use async_or::{async_or, await_or}; use spin::Mutex; extern crate alloc; use alloc::sync::Arc; @@ -34,28 +35,27 @@ impl RequesterContext { } } - pub async fn init_connection( - &mut self, - transcript_vca: &mut Option, - ) -> SpdmResult { + #[async_or] + pub fn init_connection(&mut self, transcript_vca: &mut Option) -> SpdmResult { *transcript_vca = None; - self.send_receive_spdm_version().await?; - self.send_receive_spdm_capability().await?; - self.send_receive_spdm_algorithm().await?; + await_or!(self.send_receive_spdm_version())?; + await_or!(self.send_receive_spdm_capability())?; + await_or!(self.send_receive_spdm_algorithm())?; *transcript_vca = Some(self.common.runtime_info.message_a.clone()); Ok(()) } - pub async fn start_session( + #[async_or] + pub fn start_session( &mut self, use_psk: bool, slot_id: u8, measurement_summary_hash_type: SpdmMeasurementSummaryHashType, ) -> SpdmResult { if !use_psk { - let session_id = self - .send_receive_spdm_key_exchange(slot_id, measurement_summary_hash_type) - .await?; + let session_id = await_or!( + self.send_receive_spdm_key_exchange(slot_id, measurement_summary_hash_type) + )?; #[cfg(not(feature = "mut-auth"))] let req_slot_id: Option = None; #[cfg(feature = "mut-auth")] @@ -71,30 +71,31 @@ impl RequesterContext { .req_capabilities_sel .contains(SpdmRequestCapabilityFlags::MUT_AUTH_CAP) { - self.session_based_mutual_authenticate(session_id).await?; + await_or!(self.session_based_mutual_authenticate(session_id))?; Some(self.common.runtime_info.get_local_used_cert_chain_slot_id()) } else { None } }; - self.send_receive_spdm_finish(req_slot_id, session_id) - .await?; + await_or!(self.send_receive_spdm_finish(req_slot_id, session_id))?; Ok(session_id) } else { - let session_id = self - .send_receive_spdm_psk_exchange(measurement_summary_hash_type, None) - .await?; - self.send_receive_spdm_psk_finish(session_id).await?; + let session_id = await_or!( + self.send_receive_spdm_psk_exchange(measurement_summary_hash_type, None) + )?; + await_or!(self.send_receive_spdm_psk_finish(session_id))?; Ok(session_id) } } - pub async fn end_session(&mut self, session_id: u32) -> SpdmResult { - self.send_receive_spdm_end_session(session_id).await + #[async_or] + pub fn end_session(&mut self, session_id: u32) -> SpdmResult { + await_or!(self.send_receive_spdm_end_session(session_id)) } - pub async fn send_message( + #[async_or] + pub fn send_message( &mut self, session_id: Option, send_buffer: &[u8], @@ -112,28 +113,25 @@ impl RequesterContext { let mut transport_buffer = [0u8; config::SENDER_BUFFER_SIZE]; let used = if let Some(session_id) = session_id { - self.common - .encode_secured_message( - session_id, - send_buffer, - &mut transport_buffer, - true, - is_app_message, - ) - .await? + await_or!(self.common.encode_secured_message( + session_id, + send_buffer, + &mut transport_buffer, + true, + is_app_message, + ))? } else { - self.common - .encap(send_buffer, &mut transport_buffer) - .await? + await_or!(self.common.encap(send_buffer, &mut transport_buffer))? }; let mut device_io = self.common.device_io.lock(); let device_io: &mut (dyn SpdmDeviceIo + Send + Sync) = device_io.deref_mut(); - device_io.send(Arc::new(&transport_buffer[..used])).await + await_or!(device_io.send(Arc::new(&transport_buffer[..used]))) } - pub async fn receive_message( + #[async_or] + pub fn receive_message( &mut self, session_id: Option, receive_buffer: &mut [u8], @@ -153,20 +151,18 @@ impl RequesterContext { let mut device_io = self.common.device_io.lock(); let device_io: &mut (dyn SpdmDeviceIo + Send + Sync) = device_io.deref_mut(); - device_io - .receive(Arc::new(Mutex::new(&mut transport_buffer)), timeout) - .await + await_or!(device_io.receive(Arc::new(Mutex::new(&mut transport_buffer)), timeout)) .map_err(|_| SPDM_STATUS_RECEIVE_FAIL)? }; if let Some(session_id) = session_id { - self.common - .decode_secured_message(session_id, &transport_buffer[..used], receive_buffer) - .await + await_or!(self.common.decode_secured_message( + session_id, + &transport_buffer[..used], + receive_buffer + )) } else { - self.common - .decap(&transport_buffer[..used], receive_buffer) - .await + await_or!(self.common.decap(&transport_buffer[..used], receive_buffer)) } } } diff --git a/spdmlib/src/requester/encap_req.rs b/spdmlib/src/requester/encap_req.rs index e3bc885a..bb739e2d 100644 --- a/spdmlib/src/requester/encap_req.rs +++ b/spdmlib/src/requester/encap_req.rs @@ -2,6 +2,7 @@ // // SPDX-License-Identifier: Apache-2.0 or MIT +use async_or::{async_or, await_or}; use codec::{Codec, Reader, Writer}; use crate::{ @@ -26,7 +27,8 @@ use crate::{ use super::RequesterContext; impl RequesterContext { - pub async fn get_encapsulated_request_response( + #[async_or] + pub fn get_encapsulated_request_response( &mut self, session_id: u32, mut_auth_requested: SpdmKeyExchangeMutAuthAttributes, @@ -66,20 +68,20 @@ impl RequesterContext { ), }; let _ = get_digest_request.spdm_encode(&mut self.common, &mut writer)?; - self.process_encapsulated_request(session_id, 0, &encapsulated_request) - .await?; + await_or!(self.process_encapsulated_request(session_id, 0, &encapsulated_request))?; } _ => { - self.send_get_encapsulated_request(session_id).await?; - self.receive_encapsulated_request(session_id).await?; + await_or!(self.send_get_encapsulated_request(session_id))?; + await_or!(self.receive_encapsulated_request(session_id))?; } } - while self.receive_encapsulated_response_ack(session_id).await? {} + while await_or!(self.receive_encapsulated_response_ack(session_id))? {} Ok(()) } - pub async fn send_get_encapsulated_request(&mut self, session_id: u32) -> SpdmResult { + #[async_or] + pub fn send_get_encapsulated_request(&mut self, session_id: u32) -> SpdmResult { let mut send_buffer = [0u8; 4]; let mut writer = Writer::init(&mut send_buffer); let get_encap_request = SpdmMessage { @@ -93,15 +95,13 @@ impl RequesterContext { }; let _ = get_encap_request.spdm_encode(&mut self.common, &mut writer)?; - self.send_message(Some(session_id), writer.mut_used_slice(), false) - .await + await_or!(self.send_message(Some(session_id), writer.mut_used_slice(), false)) } - pub async fn receive_encapsulated_request(&mut self, session_id: u32) -> SpdmResult { + #[async_or] + pub fn receive_encapsulated_request(&mut self, session_id: u32) -> SpdmResult { let mut receive_buffer = [0u8; config::MAX_SPDM_MSG_SIZE]; - let _ = self - .receive_message(Some(session_id), &mut receive_buffer, false) - .await?; + let _ = await_or!(self.receive_message(Some(session_id), &mut receive_buffer, false))?; let mut reader = Reader::init(&receive_buffer); let header = SpdmMessageHeader::read(&mut reader).ok_or(SPDM_STATUS_INVALID_MSG_SIZE)?; @@ -116,19 +116,17 @@ impl RequesterContext { SpdmEncapsulatedRequestPayload::spdm_read(&mut self.common, &mut reader) .ok_or(SPDM_STATUS_INVALID_MSG_SIZE)?; - self.process_encapsulated_request( + await_or!(self.process_encapsulated_request( session_id, encapsulated_request.request_id, &receive_buffer[reader.used()..], - ) - .await + )) } - pub async fn receive_encapsulated_response_ack(&mut self, session_id: u32) -> SpdmResult { + #[async_or] + pub fn receive_encapsulated_response_ack(&mut self, session_id: u32) -> SpdmResult { let mut receive_buffer = [0u8; config::MAX_SPDM_MSG_SIZE]; - let size = self - .receive_message(Some(session_id), &mut receive_buffer, false) - .await?; + let size = await_or!(self.receive_message(Some(session_id), &mut receive_buffer, false))?; let mut reader = Reader::init(&receive_buffer); let header = SpdmMessageHeader::read(&mut reader).ok_or(SPDM_STATUS_INVALID_MSG_SIZE)?; @@ -169,17 +167,17 @@ impl RequesterContext { _ => {} } - self.process_encapsulated_request( + await_or!(self.process_encapsulated_request( session_id, ack_header.request_id, &receive_buffer[reader.used()..], - ) - .await?; + ))?; Ok(true) } - async fn process_encapsulated_request( + #[async_or] + fn process_encapsulated_request( &mut self, session_id: u32, request_id: u8, @@ -218,7 +216,6 @@ impl RequesterContext { ), } - self.send_message(Some(session_id), writer.used_slice(), false) - .await + await_or!(self.send_message(Some(session_id), writer.used_slice(), false)) } } diff --git a/spdmlib/src/requester/end_session_req.rs b/spdmlib/src/requester/end_session_req.rs index 9d363425..210e5a1b 100644 --- a/spdmlib/src/requester/end_session_req.rs +++ b/spdmlib/src/requester/end_session_req.rs @@ -2,6 +2,8 @@ // // SPDX-License-Identifier: Apache-2.0 or MIT +use async_or::{async_or, await_or}; + use crate::error::{ SpdmResult, SPDM_STATUS_ERROR_PEER, SPDM_STATUS_INVALID_MSG_FIELD, SPDM_STATUS_INVALID_PARAMETER, @@ -10,7 +12,8 @@ use crate::message::*; use crate::requester::*; impl RequesterContext { - pub async fn send_receive_spdm_end_session(&mut self, session_id: u32) -> SpdmResult { + #[async_or] + pub fn send_receive_spdm_end_session(&mut self, session_id: u32) -> SpdmResult { info!("send spdm end_session\n"); self.common.reset_buffer_via_request_code( @@ -20,13 +23,10 @@ impl RequesterContext { let mut send_buffer = [0u8; config::MAX_SPDM_MSG_SIZE]; let used = self.encode_spdm_end_session(&mut send_buffer)?; - self.send_message(Some(session_id), &send_buffer[..used], false) - .await?; + await_or!(self.send_message(Some(session_id), &send_buffer[..used], false))?; let mut receive_buffer = [0u8; config::MAX_SPDM_MSG_SIZE]; - let used = self - .receive_message(Some(session_id), &mut receive_buffer, false) - .await?; + let used = await_or!(self.receive_message(Some(session_id), &mut receive_buffer, false))?; self.handle_spdm_end_session_response(session_id, &receive_buffer[..used]) } diff --git a/spdmlib/src/requester/finish_req.rs b/spdmlib/src/requester/finish_req.rs index c5c6eb10..f2c3a7d4 100644 --- a/spdmlib/src/requester/finish_req.rs +++ b/spdmlib/src/requester/finish_req.rs @@ -9,19 +9,19 @@ use crate::protocol::*; use crate::requester::*; extern crate alloc; use alloc::boxed::Box; +use async_or::async_or; +use async_or::await_or; impl RequesterContext { - pub async fn send_receive_spdm_finish( + #[async_or] + pub fn send_receive_spdm_finish( &mut self, req_slot_id: Option, session_id: u32, ) -> SpdmResult { info!("send spdm finish\n"); - if let Err(e) = self - .delegate_send_receive_spdm_finish(req_slot_id, session_id) - .await - { + if let Err(e) = await_or!(self.delegate_send_receive_spdm_finish(req_slot_id, session_id)) { if let Some(session) = self.common.get_session_via_id(session_id) { session.teardown(); } @@ -32,7 +32,8 @@ impl RequesterContext { } } - pub async fn delegate_send_receive_spdm_finish( + #[async_or] + pub fn delegate_send_receive_spdm_finish( &mut self, req_slot_id: Option, session_id: u32, @@ -81,11 +82,9 @@ impl RequesterContext { } let send_used = res.unwrap(); let res = if in_clear_text { - self.send_message(None, &send_buffer[..send_used], false) - .await + await_or!(self.send_message(None, &send_buffer[..send_used], false)) } else { - self.send_message(Some(session_id), &send_buffer[..send_used], false) - .await + await_or!(self.send_message(Some(session_id), &send_buffer[..send_used], false)) }; if res.is_err() { self.common @@ -97,10 +96,9 @@ impl RequesterContext { let mut receive_buffer = [0u8; config::MAX_SPDM_MSG_SIZE]; let res = if in_clear_text { - self.receive_message(None, &mut receive_buffer, false).await + await_or!(self.receive_message(None, &mut receive_buffer, false)) } else { - self.receive_message(Some(session_id), &mut receive_buffer, false) - .await + await_or!(self.receive_message(Some(session_id), &mut receive_buffer, false)) }; if res.is_err() { self.common diff --git a/spdmlib/src/requester/get_capabilities_req.rs b/spdmlib/src/requester/get_capabilities_req.rs index e73c9d53..4d24c175 100644 --- a/spdmlib/src/requester/get_capabilities_req.rs +++ b/spdmlib/src/requester/get_capabilities_req.rs @@ -2,13 +2,16 @@ // // SPDX-License-Identifier: Apache-2.0 or MIT +use async_or::{async_or, await_or}; + use crate::error::{SpdmResult, SPDM_STATUS_ERROR_PEER, SPDM_STATUS_INVALID_MSG_FIELD}; use crate::message::*; use crate::protocol::*; use crate::requester::*; impl RequesterContext { - pub async fn send_receive_spdm_capability(&mut self) -> SpdmResult { + #[async_or] + pub fn send_receive_spdm_capability(&mut self) -> SpdmResult { self.common.reset_buffer_via_request_code( SpdmRequestResponseCode::SpdmRequestGetCapabilities, None, @@ -16,13 +19,10 @@ impl RequesterContext { let mut send_buffer = [0u8; config::MAX_SPDM_MSG_SIZE]; let send_used = self.encode_spdm_capability(&mut send_buffer)?; - self.send_message(None, &send_buffer[..send_used], false) - .await?; + await_or!(self.send_message(None, &send_buffer[..send_used], false))?; let mut receive_buffer = [0u8; config::MAX_SPDM_MSG_SIZE]; - let used = self - .receive_message(None, &mut receive_buffer, false) - .await?; + let used = await_or!(self.receive_message(None, &mut receive_buffer, false))?; self.handle_spdm_capability_response(0, &send_buffer[..send_used], &receive_buffer[..used]) } diff --git a/spdmlib/src/requester/get_certificate_req.rs b/spdmlib/src/requester/get_certificate_req.rs index 76b55d2a..b39608b4 100644 --- a/spdmlib/src/requester/get_certificate_req.rs +++ b/spdmlib/src/requester/get_certificate_req.rs @@ -2,6 +2,8 @@ // // SPDX-License-Identifier: Apache-2.0 or MIT +use async_or::{async_or, await_or}; + use crate::crypto::{self, is_root_certificate}; use crate::error::{ SpdmResult, SPDM_STATUS_CRYPTO_ERROR, SPDM_STATUS_ERROR_PEER, SPDM_STATUS_INVALID_CERT, @@ -12,7 +14,8 @@ use crate::protocol::*; use crate::requester::*; impl RequesterContext { - async fn send_receive_spdm_certificate_partial( + #[async_or] + fn send_receive_spdm_certificate_partial( &mut self, session_id: Option, slot_id: u8, @@ -25,13 +28,10 @@ impl RequesterContext { let send_used = self.encode_spdm_certificate_partial(slot_id, offset, length, &mut send_buffer)?; - self.send_message(session_id, &send_buffer[..send_used], false) - .await?; + await_or!(self.send_message(session_id, &send_buffer[..send_used], false))?; let mut receive_buffer = [0u8; config::MAX_SPDM_MSG_SIZE]; - let used = self - .receive_message(session_id, &mut receive_buffer, false) - .await?; + let used = await_or!(self.receive_message(session_id, &mut receive_buffer, false))?; self.handle_spdm_certificate_partial_response( session_id, @@ -171,7 +171,8 @@ impl RequesterContext { } } - pub async fn send_receive_spdm_certificate( + #[async_or] + pub fn send_receive_spdm_certificate( &mut self, session_id: Option, slot_id: u8, @@ -191,11 +192,10 @@ impl RequesterContext { self.common.peer_info.peer_cert_chain_temp = Some(SpdmCertChainBuffer::default()); while length != 0 { - let (portion_length, remainder_length) = self + let (portion_length, remainder_length) = await_or!(self .send_receive_spdm_certificate_partial( session_id, slot_id, total_size, offset, length, - ) - .await?; + ))?; if total_size == 0 { total_size = portion_length + remainder_length; } diff --git a/spdmlib/src/requester/get_digests_req.rs b/spdmlib/src/requester/get_digests_req.rs index 8b12887a..37f84fbd 100644 --- a/spdmlib/src/requester/get_digests_req.rs +++ b/spdmlib/src/requester/get_digests_req.rs @@ -2,12 +2,15 @@ // // SPDX-License-Identifier: Apache-2.0 or MIT +use async_or::{async_or, await_or}; + use crate::error::{SpdmResult, SPDM_STATUS_ERROR_PEER, SPDM_STATUS_INVALID_MSG_FIELD}; use crate::message::*; use crate::requester::*; impl RequesterContext { - pub async fn send_receive_spdm_digest(&mut self, session_id: Option) -> SpdmResult { + #[async_or] + pub fn send_receive_spdm_digest(&mut self, session_id: Option) -> SpdmResult { info!("send spdm digest\n"); self.common.reset_buffer_via_request_code( @@ -18,13 +21,10 @@ impl RequesterContext { let mut send_buffer = [0u8; config::MAX_SPDM_MSG_SIZE]; let send_used = self.encode_spdm_digest(&mut send_buffer)?; - self.send_message(session_id, &send_buffer[..send_used], false) - .await?; + await_or!(self.send_message(session_id, &send_buffer[..send_used], false))?; let mut receive_buffer = [0u8; config::MAX_SPDM_MSG_SIZE]; - let used = self - .receive_message(session_id, &mut receive_buffer, false) - .await?; + let used = await_or!(self.receive_message(session_id, &mut receive_buffer, false))?; self.handle_spdm_digest_response( session_id, diff --git a/spdmlib/src/requester/get_measurements_req.rs b/spdmlib/src/requester/get_measurements_req.rs index 0137f89d..b473bf8e 100644 --- a/spdmlib/src/requester/get_measurements_req.rs +++ b/spdmlib/src/requester/get_measurements_req.rs @@ -2,6 +2,8 @@ // // SPDX-License-Identifier: Apache-2.0 or MIT +use async_or::{async_or, await_or}; + use crate::crypto; #[cfg(feature = "hashed-transcript-data")] use crate::error::SPDM_STATUS_INVALID_STATE_LOCAL; @@ -16,7 +18,8 @@ use crate::requester::*; impl RequesterContext { #[allow(clippy::too_many_arguments)] - async fn send_receive_spdm_measurement_record( + #[async_or] + fn send_receive_spdm_measurement_record( &mut self, session_id: Option, measurement_attributes: SpdmMeasurementAttributes, @@ -30,17 +33,15 @@ impl RequesterContext { *transcript_meas = Some(ManagedBufferM::default()); } - let result = self - .delegate_send_receive_spdm_measurement_record( - session_id, - measurement_attributes, - measurement_operation, - content_changed, - spdm_measurement_record_structure, - transcript_meas, - slot_id, - ) - .await; + let result = await_or!(self.delegate_send_receive_spdm_measurement_record( + session_id, + measurement_attributes, + measurement_operation, + content_changed, + spdm_measurement_record_structure, + transcript_meas, + slot_id, + )); if let Err(e) = result { if e != SPDM_STATUS_NOT_READY_PEER { @@ -53,7 +54,8 @@ impl RequesterContext { } #[allow(clippy::too_many_arguments)] - async fn delegate_send_receive_spdm_measurement_record( + #[async_or] + fn delegate_send_receive_spdm_measurement_record( &mut self, session_id: Option, measurement_attributes: SpdmMeasurementAttributes, @@ -81,14 +83,11 @@ impl RequesterContext { slot_id, &mut send_buffer, )?; - self.send_message(session_id, &send_buffer[..send_used], false) - .await?; + await_or!(self.send_message(session_id, &send_buffer[..send_used], false))?; // Receive let mut receive_buffer = [0u8; config::MAX_SPDM_MSG_SIZE]; - let used = self - .receive_message(session_id, &mut receive_buffer, true) - .await?; + let used = await_or!(self.receive_message(session_id, &mut receive_buffer, true))?; self.handle_spdm_measurement_record_response( session_id, @@ -269,7 +268,8 @@ impl RequesterContext { } #[allow(clippy::too_many_arguments)] - pub async fn send_receive_spdm_measurement( + #[async_or] + pub fn send_receive_spdm_measurement( &mut self, session_id: Option, slot_id: u8, @@ -281,17 +281,15 @@ impl RequesterContext { spdm_measurement_record_structure: &mut SpdmMeasurementRecordStructure, // out transcript_meas: &mut Option, // out ) -> SpdmResult { - *out_total_number = self - .send_receive_spdm_measurement_record( - session_id, - spdm_measuremente_attributes, - measurement_operation, - content_changed, - spdm_measurement_record_structure, - transcript_meas, - slot_id, - ) - .await?; + *out_total_number = await_or!(self.send_receive_spdm_measurement_record( + session_id, + spdm_measuremente_attributes, + measurement_operation, + content_changed, + spdm_measurement_record_structure, + transcript_meas, + slot_id, + ))?; Ok(()) } diff --git a/spdmlib/src/requester/get_version_req.rs b/spdmlib/src/requester/get_version_req.rs index 474686e5..74479d65 100644 --- a/spdmlib/src/requester/get_version_req.rs +++ b/spdmlib/src/requester/get_version_req.rs @@ -2,6 +2,8 @@ // // SPDX-License-Identifier: Apache-2.0 or MIT +use async_or::{async_or, await_or}; + use crate::error::{ SpdmResult, SPDM_STATUS_ERROR_PEER, SPDM_STATUS_INVALID_MSG_FIELD, SPDM_STATUS_NEGOTIATION_FAIL, }; @@ -10,19 +12,17 @@ use crate::protocol::*; use crate::requester::*; impl RequesterContext { - pub async fn send_receive_spdm_version(&mut self) -> SpdmResult { + #[async_or] + pub fn send_receive_spdm_version(&mut self) -> SpdmResult { // reset context on get version request self.common.reset_context(); let mut send_buffer = [0u8; config::MAX_SPDM_MSG_SIZE]; let send_used = self.encode_spdm_version(&mut send_buffer)?; - self.send_message(None, &send_buffer[..send_used], false) - .await?; + await_or!(self.send_message(None, &send_buffer[..send_used], false))?; let mut receive_buffer = [0u8; config::MAX_SPDM_MSG_SIZE]; - let used = self - .receive_message(None, &mut receive_buffer, false) - .await?; + let used = await_or!(self.receive_message(None, &mut receive_buffer, false))?; self.handle_spdm_version_response(0, &send_buffer[..send_used], &receive_buffer[..used]) } diff --git a/spdmlib/src/requester/heartbeat_req.rs b/spdmlib/src/requester/heartbeat_req.rs index 245e5167..64bbbeda 100644 --- a/spdmlib/src/requester/heartbeat_req.rs +++ b/spdmlib/src/requester/heartbeat_req.rs @@ -2,12 +2,15 @@ // // SPDX-License-Identifier: Apache-2.0 or MIT +use async_or::{async_or, await_or}; + use crate::error::{SpdmResult, SPDM_STATUS_ERROR_PEER, SPDM_STATUS_INVALID_MSG_FIELD}; use crate::message::*; use crate::requester::*; impl RequesterContext { - pub async fn send_receive_spdm_heartbeat(&mut self, session_id: u32) -> SpdmResult { + #[async_or] + pub fn send_receive_spdm_heartbeat(&mut self, session_id: u32) -> SpdmResult { info!("send spdm heartbeat\n"); self.common.reset_buffer_via_request_code( @@ -17,14 +20,11 @@ impl RequesterContext { let mut send_buffer = [0u8; config::MAX_SPDM_MSG_SIZE]; let used = self.encode_spdm_heartbeat(&mut send_buffer)?; - self.send_message(Some(session_id), &send_buffer[..used], false) - .await?; + await_or!(self.send_message(Some(session_id), &send_buffer[..used], false))?; // Receive let mut receive_buffer = [0u8; config::MAX_SPDM_MSG_SIZE]; - let used = self - .receive_message(Some(session_id), &mut receive_buffer, false) - .await?; + let used = await_or!(self.receive_message(Some(session_id), &mut receive_buffer, false))?; self.handle_spdm_heartbeat_response(session_id, &receive_buffer[..used]) } diff --git a/spdmlib/src/requester/key_exchange_req.rs b/spdmlib/src/requester/key_exchange_req.rs index c9b45ef2..75f32c9b 100644 --- a/spdmlib/src/requester/key_exchange_req.rs +++ b/spdmlib/src/requester/key_exchange_req.rs @@ -4,6 +4,8 @@ extern crate alloc; use alloc::boxed::Box; +use async_or::async_or; +use async_or::await_or; use core::ops::DerefMut; use crate::common::session::SpdmSession; @@ -27,7 +29,8 @@ use crate::message::*; use crate::protocol::{SpdmMeasurementSummaryHashType, SpdmSignatureStruct, SpdmVersion}; impl RequesterContext { - pub async fn send_receive_spdm_key_exchange( + #[async_or] + pub fn send_receive_spdm_key_exchange( &mut self, slot_id: u8, measurement_summary_hash_type: SpdmMeasurementSummaryHashType, @@ -50,14 +53,11 @@ impl RequesterContext { slot_id, measurement_summary_hash_type, )?; - self.send_message(None, &send_buffer[..send_used], false) - .await?; + await_or!(self.send_message(None, &send_buffer[..send_used], false))?; // Receive let mut receive_buffer = [0u8; config::MAX_SPDM_MSG_SIZE]; - let receive_used = self - .receive_message(None, &mut receive_buffer, false) - .await?; + let receive_used = await_or!(self.receive_message(None, &mut receive_buffer, false))?; let mut target_session_id = None; if let Err(e) = self.handle_spdm_key_exchange_response( diff --git a/spdmlib/src/requester/key_update_req.rs b/spdmlib/src/requester/key_update_req.rs index 49d3ea7d..3ce14c0b 100644 --- a/spdmlib/src/requester/key_update_req.rs +++ b/spdmlib/src/requester/key_update_req.rs @@ -2,6 +2,8 @@ // // SPDX-License-Identifier: Apache-2.0 or MIT +use async_or::{async_or, await_or}; + use crate::error::{ SpdmResult, SPDM_STATUS_ERROR_PEER, SPDM_STATUS_INVALID_MSG_FIELD, SPDM_STATUS_INVALID_PARAMETER, @@ -10,7 +12,8 @@ use crate::message::*; use crate::requester::*; impl RequesterContext { - async fn send_receive_spdm_key_update_op( + #[async_or] + fn send_receive_spdm_key_update_op( &mut self, session_id: u32, key_update_operation: SpdmKeyUpdateOperation, @@ -25,8 +28,7 @@ impl RequesterContext { let mut send_buffer = [0u8; config::MAX_SPDM_MSG_SIZE]; let used = self.encode_spdm_key_update_op(key_update_operation, tag, &mut send_buffer)?; - self.send_message(Some(session_id), &send_buffer[..used], false) - .await?; + await_or!(self.send_message(Some(session_id), &send_buffer[..used], false))?; // update key let spdm_version_sel = self.common.negotiate_info.spdm_version_sel; @@ -40,9 +42,7 @@ impl RequesterContext { let update_responder = key_update_operation == SpdmKeyUpdateOperation::SpdmUpdateAllKeys; session.create_data_secret_update(spdm_version_sel, update_requester, update_responder)?; let mut receive_buffer = [0u8; config::MAX_SPDM_MSG_SIZE]; - let used = self - .receive_message(Some(session_id), &mut receive_buffer, false) - .await?; + let used = await_or!(self.receive_message(Some(session_id), &mut receive_buffer, false))?; self.handle_spdm_key_update_op_response( session_id, @@ -129,7 +129,8 @@ impl RequesterContext { } } - pub async fn send_receive_spdm_key_update( + #[async_or] + pub fn send_receive_spdm_key_update( &mut self, session_id: u32, key_update_operation: SpdmKeyUpdateOperation, @@ -139,13 +140,11 @@ impl RequesterContext { { return Err(SPDM_STATUS_INVALID_MSG_FIELD); } - self.send_receive_spdm_key_update_op(session_id, key_update_operation, 1) - .await?; - self.send_receive_spdm_key_update_op( + await_or!(self.send_receive_spdm_key_update_op(session_id, key_update_operation, 1))?; + await_or!(self.send_receive_spdm_key_update_op( session_id, SpdmKeyUpdateOperation::SpdmVerifyNewKey, 2, - ) - .await + )) } } diff --git a/spdmlib/src/requester/mutual_authenticate.rs b/spdmlib/src/requester/mutual_authenticate.rs index 3478f692..6944d094 100644 --- a/spdmlib/src/requester/mutual_authenticate.rs +++ b/spdmlib/src/requester/mutual_authenticate.rs @@ -2,6 +2,8 @@ // // SPDX-License-Identifier: Apache-2.0 or MIT +use async_or::{async_or, await_or}; + use crate::{ error::{SpdmResult, SPDM_STATUS_INVALID_MSG_FIELD, SPDM_STATUS_INVALID_STATE_LOCAL}, message::SpdmKeyExchangeMutAuthAttributes, @@ -10,7 +12,8 @@ use crate::{ use super::RequesterContext; impl RequesterContext { - pub async fn session_based_mutual_authenticate(&mut self, session_id: u32) -> SpdmResult<()> { + #[async_or] + pub fn session_based_mutual_authenticate(&mut self, session_id: u32) -> SpdmResult<()> { self.common.construct_my_cert_chain()?; let spdm_session = self @@ -23,8 +26,7 @@ impl RequesterContext { SpdmKeyExchangeMutAuthAttributes::MUT_AUTH_REQ => Ok(()), SpdmKeyExchangeMutAuthAttributes::MUT_AUTH_REQ_WITH_ENCAP_REQUEST | SpdmKeyExchangeMutAuthAttributes::MUT_AUTH_REQ_WITH_GET_DIGESTS => { - self.get_encapsulated_request_response(session_id, mut_auth_requested) - .await + await_or!(self.get_encapsulated_request_response(session_id, mut_auth_requested)) } _ => Err(SPDM_STATUS_INVALID_MSG_FIELD), } diff --git a/spdmlib/src/requester/negotiate_algorithms_req.rs b/spdmlib/src/requester/negotiate_algorithms_req.rs index 5fb7def0..8fce8095 100644 --- a/spdmlib/src/requester/negotiate_algorithms_req.rs +++ b/spdmlib/src/requester/negotiate_algorithms_req.rs @@ -2,6 +2,8 @@ // // SPDX-License-Identifier: Apache-2.0 or MIT +use async_or::{async_or, await_or}; + use crate::error::{ SpdmResult, SPDM_STATUS_ERROR_PEER, SPDM_STATUS_INVALID_MSG_FIELD, SPDM_STATUS_NEGOTIATION_FAIL, }; @@ -11,7 +13,8 @@ use crate::protocol::*; use crate::requester::*; impl RequesterContext { - pub async fn send_receive_spdm_algorithm(&mut self) -> SpdmResult { + #[async_or] + pub fn send_receive_spdm_algorithm(&mut self) -> SpdmResult { self.common.reset_buffer_via_request_code( SpdmRequestResponseCode::SpdmRequestNegotiateAlgorithms, None, @@ -19,13 +22,10 @@ impl RequesterContext { let mut send_buffer = [0u8; config::MAX_SPDM_MSG_SIZE]; let send_used = self.encode_spdm_algorithm(&mut send_buffer)?; - self.send_message(None, &send_buffer[..send_used], false) - .await?; + await_or!(self.send_message(None, &send_buffer[..send_used], false))?; let mut receive_buffer = [0u8; config::MAX_SPDM_MSG_SIZE]; - let used = self - .receive_message(None, &mut receive_buffer, false) - .await?; + let used = await_or!(self.receive_message(None, &mut receive_buffer, false))?; self.handle_spdm_algorithm_response(0, &send_buffer[..send_used], &receive_buffer[..used]) } diff --git a/spdmlib/src/requester/psk_exchange_req.rs b/spdmlib/src/requester/psk_exchange_req.rs index 8166fa1a..6b8ce887 100644 --- a/spdmlib/src/requester/psk_exchange_req.rs +++ b/spdmlib/src/requester/psk_exchange_req.rs @@ -2,6 +2,8 @@ // // SPDX-License-Identifier: Apache-2.0 or MIT +use async_or::async_or; +use async_or::await_or; use config::MAX_SPDM_PSK_CONTEXT_SIZE; use crate::crypto; @@ -19,7 +21,8 @@ extern crate alloc; use core::ops::DerefMut; impl RequesterContext { - pub async fn send_receive_spdm_psk_exchange( + #[async_or] + pub fn send_receive_spdm_psk_exchange( &mut self, measurement_summary_hash_type: SpdmMeasurementSummaryHashType, psk_hint: Option<&SpdmPskHintStruct>, @@ -44,14 +47,11 @@ impl RequesterContext { &mut send_buffer, )?; - self.send_message(None, &send_buffer[..send_used], false) - .await?; + await_or!(self.send_message(None, &send_buffer[..send_used], false))?; // Receive let mut receive_buffer = [0u8; config::MAX_SPDM_MSG_SIZE]; - let receive_used = self - .receive_message(None, &mut receive_buffer, false) - .await?; + let receive_used = await_or!(self.receive_message(None, &mut receive_buffer, false))?; let mut target_session_id = None; if let Err(e) = self.handle_spdm_psk_exchange_response( diff --git a/spdmlib/src/requester/psk_finish_req.rs b/spdmlib/src/requester/psk_finish_req.rs index 1b1f2159..374402b2 100644 --- a/spdmlib/src/requester/psk_finish_req.rs +++ b/spdmlib/src/requester/psk_finish_req.rs @@ -11,12 +11,14 @@ use crate::protocol::*; use crate::requester::*; extern crate alloc; use alloc::boxed::Box; +use async_or::{async_or, await_or}; impl RequesterContext { - pub async fn send_receive_spdm_psk_finish(&mut self, session_id: u32) -> SpdmResult { + #[async_or] + pub fn send_receive_spdm_psk_finish(&mut self, session_id: u32) -> SpdmResult { info!("send spdm psk_finish\n"); - if let Err(e) = self.delegate_send_receive_spdm_psk_finish(session_id).await { + if let Err(e) = await_or!(self.delegate_send_receive_spdm_psk_finish(session_id)) { if let Some(session) = self.common.get_session_via_id(session_id) { session.teardown(); } @@ -27,7 +29,8 @@ impl RequesterContext { } } - pub async fn delegate_send_receive_spdm_psk_finish(&mut self, session_id: u32) -> SpdmResult { + #[async_or] + pub fn delegate_send_receive_spdm_psk_finish(&mut self, session_id: u32) -> SpdmResult { if self.common.get_session_via_id(session_id).is_none() { return Err(SPDM_STATUS_INVALID_PARAMETER); } @@ -47,9 +50,7 @@ impl RequesterContext { return Err(res.err().unwrap()); } let send_used = res.unwrap(); - let res = self - .send_message(Some(session_id), &send_buffer[..send_used], false) - .await; + let res = await_or!(self.send_message(Some(session_id), &send_buffer[..send_used], false)); if res.is_err() { self.common .get_session_via_id(session_id) @@ -59,9 +60,7 @@ impl RequesterContext { } let mut receive_buffer = [0u8; config::MAX_SPDM_MSG_SIZE]; - let res = self - .receive_message(Some(session_id), &mut receive_buffer, false) - .await; + let res = await_or!(self.receive_message(Some(session_id), &mut receive_buffer, false)); if res.is_err() { self.common .get_session_via_id(session_id) diff --git a/spdmlib/src/requester/vendor_req.rs b/spdmlib/src/requester/vendor_req.rs index 2c82591a..5cd8ef84 100644 --- a/spdmlib/src/requester/vendor_req.rs +++ b/spdmlib/src/requester/vendor_req.rs @@ -2,12 +2,15 @@ // // SPDX-License-Identifier: Apache-2.0 or MIT +use async_or::{async_or, await_or}; + use crate::error::{SpdmResult, SPDM_STATUS_ERROR_PEER, SPDM_STATUS_INVALID_MSG_FIELD}; use crate::message::*; use crate::requester::*; impl RequesterContext { - pub async fn send_spdm_vendor_defined_request( + #[async_or] + pub fn send_spdm_vendor_defined_request( &mut self, session_id: Option, standard_id: RegistryOrStandardsBodyID, @@ -38,14 +41,11 @@ impl RequesterContext { }; let used = request.spdm_encode(&mut self.common, &mut writer)?; - self.send_message(session_id, &send_buffer[..used], false) - .await?; + await_or!(self.send_message(session_id, &send_buffer[..used], false))?; //receive let mut receive_buffer = [0u8; config::MAX_SPDM_MSG_SIZE]; - let receive_used = self - .receive_message(session_id, &mut receive_buffer, false) - .await?; + let receive_used = await_or!(self.receive_message(session_id, &mut receive_buffer, false))?; self.handle_spdm_vendor_defined_respond(session_id, &receive_buffer[..receive_used]) } diff --git a/spdmlib/src/responder/context.rs b/spdmlib/src/responder/context.rs index a1ab2f5b..4caf87b2 100644 --- a/spdmlib/src/responder/context.rs +++ b/spdmlib/src/responder/context.rs @@ -10,6 +10,7 @@ use crate::error::{SpdmResult, SPDM_STATUS_INVALID_STATE_LOCAL, SPDM_STATUS_UNSU use crate::message::*; use crate::protocol::{SpdmRequestCapabilityFlags, SpdmResponseCapabilityFlags}; use crate::watchdog::{reset_watchdog, start_watchdog}; +use async_or::{async_or, await_or}; use codec::{Codec, Reader, Writer}; extern crate alloc; use core::ops::DerefMut; @@ -38,7 +39,8 @@ impl ResponderContext { } } - pub async fn send_message( + #[async_or] + pub fn send_message( &mut self, session_id: Option, send_buffer: &[u8], @@ -61,25 +63,21 @@ impl ResponderContext { let mut transport_buffer = [0u8; config::SENDER_BUFFER_SIZE]; let used = if let Some(session_id) = session_id { - self.common - .encode_secured_message( - session_id, - send_buffer, - &mut transport_buffer, - false, - is_app_message, - ) - .await? + await_or!(self.common.encode_secured_message( + session_id, + send_buffer, + &mut transport_buffer, + false, + is_app_message, + ))? } else { - self.common - .encap(send_buffer, &mut transport_buffer) - .await? + await_or!(self.common.encap(send_buffer, &mut transport_buffer))? }; { let mut device_io = self.common.device_io.lock(); let device_io: &mut (dyn SpdmDeviceIo + Send + Sync) = device_io.deref_mut(); - device_io.send(Arc::new(&transport_buffer[..used])).await?; + await_or!(device_io.send(Arc::new(&transport_buffer[..used])))?; } let opcode = send_buffer[1]; @@ -186,7 +184,8 @@ impl ResponderContext { Ok(()) } - pub async fn process_message( + #[async_or] + pub fn process_message( &mut self, crypto_request: bool, app_handle: usize, // interpreted/managed by User @@ -195,7 +194,7 @@ impl ResponderContext { let mut response_buffer = [0u8; MAX_SPDM_MSG_SIZE]; let mut writer = Writer::init(&mut response_buffer); - match self.receive_message(raw_packet, crypto_request).await { + match await_or!(self.receive_message(raw_packet, crypto_request)) { Ok((used, secured_message)) => { if secured_message { let mut read = Reader::init(&raw_packet[0..used]); @@ -220,12 +219,10 @@ impl ResponderContext { let mut transport_encap = self.common.transport_encap.lock(); let transport_encap: &mut (dyn SpdmTransportEncap + Send + Sync) = transport_encap.deref_mut(); - transport_encap - .decap_app( - Arc::new(&app_buffer[0..decode_size]), - Arc::new(Mutex::new(&mut spdm_buffer)), - ) - .await + await_or!(transport_encap.decap_app( + Arc::new(&app_buffer[0..decode_size]), + Arc::new(Mutex::new(&mut spdm_buffer)), + )) }; match decap_result { Err(_) => Err(used), @@ -252,10 +249,11 @@ impl ResponderContext { &mut writer, ); if let Some(send_buffer) = send_buffer { - if let Err(err) = self - .send_message(Some(session_id), send_buffer, false) - .await - { + if let Err(err) = await_or!(self.send_message( + Some(session_id), + send_buffer, + false + )) { Ok(Err(err)) } else { Ok(status) @@ -271,9 +269,11 @@ impl ResponderContext { &mut writer, ); if let Some(send_buffer) = send_buffer { - if let Err(err) = - self.send_message(Some(session_id), send_buffer, true).await - { + if let Err(err) = await_or!(self.send_message( + Some(session_id), + send_buffer, + true + )) { Ok(Err(err)) } else { Ok(status) @@ -288,7 +288,7 @@ impl ResponderContext { let (status, send_buffer) = self.dispatch_message(&raw_packet[0..used], &mut writer); if let Some(send_buffer) = send_buffer { - if let Err(err) = self.send_message(None, send_buffer, false).await { + if let Err(err) = await_or!(self.send_message(None, send_buffer, false)) { Ok(Err(err)) } else { Ok(status) @@ -306,7 +306,8 @@ impl ResponderContext { // whose value is not normal, will return Err to caller to handle the raw packet, // So can't swap transport_buffer and receive_buffer, even though it should be by // their name suggestion. (03.01.2022) - async fn receive_message( + #[async_or] + fn receive_message( &mut self, receive_buffer: &mut [u8], crypto_request: bool, @@ -324,22 +325,18 @@ impl ResponderContext { let used = { let mut device_io = self.common.device_io.lock(); let device_io: &mut (dyn SpdmDeviceIo + Send + Sync) = device_io.deref_mut(); - device_io - .receive(Arc::new(Mutex::new(receive_buffer)), timeout) - .await? + await_or!(device_io.receive(Arc::new(Mutex::new(receive_buffer)), timeout))? }; let (used, secured_message) = { let mut transport_encap = self.common.transport_encap.lock(); let transport_encap: &mut (dyn SpdmTransportEncap + Send + Sync) = transport_encap.deref_mut(); - transport_encap - .decap( - Arc::new(&receive_buffer[..used]), - Arc::new(Mutex::new(&mut transport_buffer)), - ) - .await - .map_err(|_| used)? + await_or!(transport_encap.decap( + Arc::new(&receive_buffer[..used]), + Arc::new(Mutex::new(&mut transport_buffer)), + )) + .map_err(|_| used)? }; receive_buffer[..used].copy_from_slice(&transport_buffer[..used]); diff --git a/tdisp/Cargo.toml b/tdisp/Cargo.toml index 69f61f8c..a5fb4510 100644 --- a/tdisp/Cargo.toml +++ b/tdisp/Cargo.toml @@ -19,5 +19,8 @@ bitflags = "1.2.1" spdmlib = { path = "../spdmlib", default-features = false, features = ["spdm-ring"]} conquer-once = { version = "0.3.2", default-features = false } spin = { version = "0.9.8" } +async-or = { path = "../async-or" } + [features] +async = ["spdmlib/async", "async-or/async"] diff --git a/tdisp/src/pci_tdisp_requester/pci_tdisp_req_bind_p2p_stream_request.rs b/tdisp/src/pci_tdisp_requester/pci_tdisp_req_bind_p2p_stream_request.rs index 3f824652..3145d230 100644 --- a/tdisp/src/pci_tdisp_requester/pci_tdisp_req_bind_p2p_stream_request.rs +++ b/tdisp/src/pci_tdisp_requester/pci_tdisp_req_bind_p2p_stream_request.rs @@ -2,6 +2,8 @@ // // SPDX-License-Identifier: Apache-2.0 or MIT +use async_or::async_or; +use async_or::await_or; use codec::Codec; use codec::Writer; use spdmlib::error::SPDM_STATUS_BUFFER_FULL; @@ -23,7 +25,8 @@ use crate::pci_tdisp::{ TdispRequestResponseCode, }; -pub async fn pci_tdisp_req_bind_p2p_stream_request( +#[async_or] +pub fn pci_tdisp_req_bind_p2p_stream_request( // IN spdm_requester: &mut RequesterContext, session_id: u32, @@ -55,14 +58,13 @@ pub async fn pci_tdisp_req_bind_p2p_stream_request( .map_err(|_| SPDM_STATUS_BUFFER_FULL)? as u16; - let vendor_defined_rsp_payload_struct = spdm_requester + let vendor_defined_rsp_payload_struct = await_or!(spdm_requester .send_spdm_vendor_defined_request( Some(session_id), STANDARD_ID, vendor_id(), vendor_defined_req_payload_struct, - ) - .await?; + ))?; if let Ok(tdisp_error) = RspTdispError::read_bytes( &vendor_defined_rsp_payload_struct.vendor_defined_rsp_payload diff --git a/tdisp/src/pci_tdisp_requester/pci_tdisp_req_get_device_interface_report.rs b/tdisp/src/pci_tdisp_requester/pci_tdisp_req_get_device_interface_report.rs index e56bbd78..a7aa2012 100644 --- a/tdisp/src/pci_tdisp_requester/pci_tdisp_req_get_device_interface_report.rs +++ b/tdisp/src/pci_tdisp_requester/pci_tdisp_req_get_device_interface_report.rs @@ -2,6 +2,8 @@ // // SPDX-License-Identifier: Apache-2.0 or MIT +use async_or::async_or; +use async_or::await_or; use codec::Codec; use codec::Writer; use spdmlib::error::SPDM_STATUS_BUFFER_FULL; @@ -26,7 +28,8 @@ use crate::pci_tdisp::MAX_PORTION_LENGTH; use crate::pci_tdisp::STANDARD_ID; use crate::pci_tdisp_requester::TdispVersion; -pub async fn pci_tdisp_req_get_device_interface_report( +#[async_or] +pub fn pci_tdisp_req_get_device_interface_report( // IN spdm_requester: &mut RequesterContext, session_id: u32, @@ -66,14 +69,13 @@ pub async fn pci_tdisp_req_get_device_interface_report( .map_err(|_| SPDM_STATUS_BUFFER_FULL)? as u16; - let vendor_defined_rsp_payload_struct = spdm_requester + let vendor_defined_rsp_payload_struct = await_or!(spdm_requester .send_spdm_vendor_defined_request( Some(session_id), STANDARD_ID, vendor_id(), vendor_defined_req_payload_struct, - ) - .await?; + ))?; if let Ok(tdisp_error) = RspTdispError::read_bytes( &vendor_defined_rsp_payload_struct.vendor_defined_rsp_payload diff --git a/tdisp/src/pci_tdisp_requester/pci_tdisp_req_get_device_interface_state.rs b/tdisp/src/pci_tdisp_requester/pci_tdisp_req_get_device_interface_state.rs index 372bcb52..55c77de4 100644 --- a/tdisp/src/pci_tdisp_requester/pci_tdisp_req_get_device_interface_state.rs +++ b/tdisp/src/pci_tdisp_requester/pci_tdisp_req_get_device_interface_state.rs @@ -2,6 +2,8 @@ // // SPDX-License-Identifier: Apache-2.0 or MIT +use async_or::async_or; +use async_or::await_or; use codec::Codec; use codec::Writer; use spdmlib::error::SPDM_STATUS_BUFFER_FULL; @@ -22,7 +24,8 @@ use crate::pci_tdisp::TdispRequestResponseCode; use crate::pci_tdisp::TdispVersion; use crate::pci_tdisp::STANDARD_ID; -pub async fn pci_tdisp_req_get_device_interface_state( +#[async_or] +pub fn pci_tdisp_req_get_device_interface_state( // IN spdm_requester: &mut RequesterContext, session_id: u32, @@ -52,14 +55,13 @@ pub async fn pci_tdisp_req_get_device_interface_state( .map_err(|_| SPDM_STATUS_BUFFER_FULL)? as u16; - let vendor_defined_rsp_payload_struct = spdm_requester + let vendor_defined_rsp_payload_struct = await_or!(spdm_requester .send_spdm_vendor_defined_request( Some(session_id), STANDARD_ID, vendor_id(), vendor_defined_req_payload_struct, - ) - .await?; + ))?; let rsp_device_interface_state = RspDeviceInterfaceState::read_bytes( &vendor_defined_rsp_payload_struct.vendor_defined_rsp_payload diff --git a/tdisp/src/pci_tdisp_requester/pci_tdisp_req_get_tdisp_capabilities.rs b/tdisp/src/pci_tdisp_requester/pci_tdisp_req_get_tdisp_capabilities.rs index 5e8a392b..d15faacf 100644 --- a/tdisp/src/pci_tdisp_requester/pci_tdisp_req_get_tdisp_capabilities.rs +++ b/tdisp/src/pci_tdisp_requester/pci_tdisp_req_get_tdisp_capabilities.rs @@ -2,6 +2,8 @@ // // SPDX-License-Identifier: Apache-2.0 or MIT +use async_or::async_or; +use async_or::await_or; use codec::Codec; use codec::Writer; use spdmlib::error::SPDM_STATUS_BUFFER_FULL; @@ -23,7 +25,8 @@ use crate::pci_tdisp::TdispVersion; use crate::pci_tdisp::STANDARD_ID; #[allow(clippy::too_many_arguments)] -pub async fn pci_tdisp_req_get_tdisp_capabilities( +#[async_or] +pub fn pci_tdisp_req_get_tdisp_capabilities( // IN spdm_requester: &mut RequesterContext, session_id: u32, @@ -59,14 +62,13 @@ pub async fn pci_tdisp_req_get_tdisp_capabilities( .encode(&mut writer) .map_err(|_| SPDM_STATUS_BUFFER_FULL)? as u16; - let vendor_defined_rsp_payload_struct = spdm_requester + let vendor_defined_rsp_payload_struct = await_or!(spdm_requester .send_spdm_vendor_defined_request( Some(session_id), STANDARD_ID, vendor_id(), vendor_defined_req_payload_struct, - ) - .await?; + ))?; let rsp_tdisp_capabilities = RspTdispCapabilities::read_bytes( &vendor_defined_rsp_payload_struct.vendor_defined_rsp_payload diff --git a/tdisp/src/pci_tdisp_requester/pci_tdisp_req_get_tdisp_version.rs b/tdisp/src/pci_tdisp_requester/pci_tdisp_req_get_tdisp_version.rs index 7b1cf1c8..a3506785 100644 --- a/tdisp/src/pci_tdisp_requester/pci_tdisp_req_get_tdisp_version.rs +++ b/tdisp/src/pci_tdisp_requester/pci_tdisp_req_get_tdisp_version.rs @@ -2,6 +2,8 @@ // // SPDX-License-Identifier: Apache-2.0 or MIT +use async_or::async_or; +use async_or::await_or; use codec::Codec; use codec::Writer; use spdmlib::error::SPDM_STATUS_BUFFER_FULL; @@ -18,7 +20,8 @@ use crate::pci_tdisp::STANDARD_ID; use crate::pci_tdisp::{ReqGetTdispVersion, TdispVersion}; use crate::pci_tdisp_requester::InterfaceId; -pub async fn pci_tdisp_req_get_tdisp_version( +#[async_or] +pub fn pci_tdisp_req_get_tdisp_version( // IN spdm_requester: &mut RequesterContext, session_id: u32, @@ -37,14 +40,13 @@ pub async fn pci_tdisp_req_get_tdisp_version( .map_err(|_| SPDM_STATUS_BUFFER_FULL)? as u16; - let vendor_defined_rsp_payload_struct = spdm_requester + let vendor_defined_rsp_payload_struct = await_or!(spdm_requester .send_spdm_vendor_defined_request( Some(session_id), STANDARD_ID, vendor_id(), vendor_defined_req_payload_struct, - ) - .await?; + ))?; let rsp_tdisp_version = RspTdispVersion::read_bytes( &vendor_defined_rsp_payload_struct.vendor_defined_rsp_payload diff --git a/tdisp/src/pci_tdisp_requester/pci_tdisp_req_lock_interface_request.rs b/tdisp/src/pci_tdisp_requester/pci_tdisp_req_lock_interface_request.rs index 0af5982d..d7011b07 100644 --- a/tdisp/src/pci_tdisp_requester/pci_tdisp_req_lock_interface_request.rs +++ b/tdisp/src/pci_tdisp_requester/pci_tdisp_req_lock_interface_request.rs @@ -2,6 +2,8 @@ // // SPDX-License-Identifier: Apache-2.0 or MIT +use async_or::async_or; +use async_or::await_or; use codec::Codec; use codec::Writer; use spdmlib::error::SPDM_STATUS_BUFFER_FULL; @@ -27,7 +29,8 @@ use crate::pci_tdisp::START_INTERFACE_NONCE_LEN; use crate::pci_tdisp_requester::TdispVersion; #[allow(clippy::too_many_arguments)] -pub async fn pci_tdisp_req_lock_interface_request( +#[async_or] +pub fn pci_tdisp_req_lock_interface_request( // IN spdm_requester: &mut RequesterContext, session_id: u32, @@ -66,14 +69,13 @@ pub async fn pci_tdisp_req_lock_interface_request( .map_err(|_| SPDM_STATUS_BUFFER_FULL)? as u16; - let vendor_defined_rsp_payload_struct = spdm_requester + let vendor_defined_rsp_payload_struct = await_or!(spdm_requester .send_spdm_vendor_defined_request( Some(session_id), STANDARD_ID, vendor_id(), vendor_defined_req_payload_struct, - ) - .await?; + ))?; if let Ok(tdisp_error) = RspTdispError::read_bytes( &vendor_defined_rsp_payload_struct.vendor_defined_rsp_payload diff --git a/tdisp/src/pci_tdisp_requester/pci_tdisp_req_set_mmio_attribute_request.rs b/tdisp/src/pci_tdisp_requester/pci_tdisp_req_set_mmio_attribute_request.rs index a644f38e..0b07a49d 100644 --- a/tdisp/src/pci_tdisp_requester/pci_tdisp_req_set_mmio_attribute_request.rs +++ b/tdisp/src/pci_tdisp_requester/pci_tdisp_req_set_mmio_attribute_request.rs @@ -2,6 +2,8 @@ // // SPDX-License-Identifier: Apache-2.0 or MIT +use async_or::async_or; +use async_or::await_or; use codec::Codec; use codec::Writer; use spdmlib::error::SPDM_STATUS_BUFFER_FULL; @@ -23,7 +25,8 @@ use crate::pci_tdisp::{ InterfaceId, TdispErrorCode, TdispMessageHeader, TdispRequestResponseCode, TdispVersion, }; -pub async fn pci_tdisp_req_set_mmio_attribute_request( +#[async_or] +pub fn pci_tdisp_req_set_mmio_attribute_request( // IN spdm_requester: &mut RequesterContext, session_id: u32, @@ -55,14 +58,13 @@ pub async fn pci_tdisp_req_set_mmio_attribute_request( .map_err(|_| SPDM_STATUS_BUFFER_FULL)? as u16; - let vendor_defined_rsp_payload_struct = spdm_requester + let vendor_defined_rsp_payload_struct = await_or!(spdm_requester .send_spdm_vendor_defined_request( Some(session_id), STANDARD_ID, vendor_id(), vendor_defined_req_payload_struct, - ) - .await?; + ))?; if let Ok(tdisp_error) = RspTdispError::read_bytes( &vendor_defined_rsp_payload_struct.vendor_defined_rsp_payload diff --git a/tdisp/src/pci_tdisp_requester/pci_tdisp_req_start_interface_request.rs b/tdisp/src/pci_tdisp_requester/pci_tdisp_req_start_interface_request.rs index a52e2c37..889167f9 100644 --- a/tdisp/src/pci_tdisp_requester/pci_tdisp_req_start_interface_request.rs +++ b/tdisp/src/pci_tdisp_requester/pci_tdisp_req_start_interface_request.rs @@ -2,6 +2,8 @@ // // SPDX-License-Identifier: Apache-2.0 or MIT +use async_or::async_or; +use async_or::await_or; use codec::Codec; use codec::Writer; use spdmlib::error::SPDM_STATUS_BUFFER_FULL; @@ -25,7 +27,8 @@ use crate::pci_tdisp::TdispVersion; use crate::pci_tdisp::STANDARD_ID; use crate::pci_tdisp::START_INTERFACE_NONCE_LEN; -pub async fn pci_tdisp_req_start_interface_request( +#[async_or] +pub fn pci_tdisp_req_start_interface_request( // IN spdm_requester: &mut RequesterContext, session_id: u32, @@ -57,14 +60,13 @@ pub async fn pci_tdisp_req_start_interface_request( .map_err(|_| SPDM_STATUS_BUFFER_FULL)? as u16; - let vendor_defined_rsp_payload_struct = spdm_requester + let vendor_defined_rsp_payload_struct = await_or!(spdm_requester .send_spdm_vendor_defined_request( Some(session_id), STANDARD_ID, vendor_id(), vendor_defined_req_payload_struct, - ) - .await?; + ))?; if let Ok(tdisp_error) = RspTdispError::read_bytes( &vendor_defined_rsp_payload_struct.vendor_defined_rsp_payload diff --git a/tdisp/src/pci_tdisp_requester/pci_tdisp_req_stop_interface_request.rs b/tdisp/src/pci_tdisp_requester/pci_tdisp_req_stop_interface_request.rs index cb51a54a..a95b939b 100644 --- a/tdisp/src/pci_tdisp_requester/pci_tdisp_req_stop_interface_request.rs +++ b/tdisp/src/pci_tdisp_requester/pci_tdisp_req_stop_interface_request.rs @@ -2,6 +2,8 @@ // // SPDX-License-Identifier: Apache-2.0 or MIT +use async_or::async_or; +use async_or::await_or; use codec::Codec; use codec::Writer; use spdmlib::error::SPDM_STATUS_BUFFER_FULL; @@ -21,7 +23,8 @@ use crate::pci_tdisp::TdispRequestResponseCode; use crate::pci_tdisp::TdispVersion; use crate::pci_tdisp::STANDARD_ID; -pub async fn pci_tdisp_req_stop_interface_request( +#[async_or] +pub fn pci_tdisp_req_stop_interface_request( // IN spdm_requester: &mut RequesterContext, session_id: u32, @@ -49,14 +52,13 @@ pub async fn pci_tdisp_req_stop_interface_request( .map_err(|_| SPDM_STATUS_BUFFER_FULL)? as u16; - let vendor_defined_rsp_payload_struct = spdm_requester + let vendor_defined_rsp_payload_struct = await_or!(spdm_requester .send_spdm_vendor_defined_request( Some(session_id), STANDARD_ID, vendor_id(), vendor_defined_req_payload_struct, - ) - .await?; + ))?; let rsp_stop_interface_response = RspStopInterfaceResponse::read_bytes( &vendor_defined_rsp_payload_struct.vendor_defined_rsp_payload diff --git a/tdisp/src/pci_tdisp_requester/pci_tdisp_req_unbind_p2p_stream_request.rs b/tdisp/src/pci_tdisp_requester/pci_tdisp_req_unbind_p2p_stream_request.rs index c3ae268a..923e2b7b 100644 --- a/tdisp/src/pci_tdisp_requester/pci_tdisp_req_unbind_p2p_stream_request.rs +++ b/tdisp/src/pci_tdisp_requester/pci_tdisp_req_unbind_p2p_stream_request.rs @@ -2,6 +2,8 @@ // // SPDX-License-Identifier: Apache-2.0 or MIT +use async_or::async_or; +use async_or::await_or; use codec::Codec; use codec::Writer; use spdmlib::error::SPDM_STATUS_BUFFER_FULL; @@ -22,7 +24,8 @@ use crate::pci_tdisp::{ InterfaceId, TdispErrorCode, TdispMessageHeader, TdispRequestResponseCode, TdispVersion, }; -pub async fn pci_tdisp_req_unbind_p2p_stream_request( +#[async_or] +pub fn pci_tdisp_req_unbind_p2p_stream_request( // IN spdm_requester: &mut RequesterContext, session_id: u32, @@ -54,14 +57,13 @@ pub async fn pci_tdisp_req_unbind_p2p_stream_request( .map_err(|_| SPDM_STATUS_BUFFER_FULL)? as u16; - let vendor_defined_rsp_payload_struct = spdm_requester + let vendor_defined_rsp_payload_struct = await_or!(spdm_requester .send_spdm_vendor_defined_request( Some(session_id), STANDARD_ID, vendor_id(), vendor_defined_req_payload_struct, - ) - .await?; + ))?; if let Ok(tdisp_error) = RspTdispError::read_bytes( &vendor_defined_rsp_payload_struct.vendor_defined_rsp_payload diff --git a/tdisp/src/pci_tdisp_requester/pci_tdisp_req_vdm_request.rs b/tdisp/src/pci_tdisp_requester/pci_tdisp_req_vdm_request.rs index 0f64aff9..7e1fb845 100644 --- a/tdisp/src/pci_tdisp_requester/pci_tdisp_req_vdm_request.rs +++ b/tdisp/src/pci_tdisp_requester/pci_tdisp_req_vdm_request.rs @@ -1,36 +1,36 @@ -// Copyright (c) 2023 Intel Corporation -// -// SPDX-License-Identifier: Apache-2.0 or MIT - -use spdmlib::error::SPDM_STATUS_INVALID_PARAMETER; -use spdmlib::message::VendorDefinedRspPayloadStruct; -use spdmlib::{ - error::SpdmResult, message::VendorDefinedReqPayloadStruct, requester::RequesterContext, -}; - -use crate::pci_tdisp::vendor_id; -use crate::pci_tdisp::STANDARD_ID; -use crate::pci_tdisp::TDISP_PROTOCOL_ID; - -pub async fn pci_tdisp_req_vdm_request( - // IN - spdm_requester: &mut RequesterContext, - session_id: u32, - vendor_defined_req_payload_struct: VendorDefinedReqPayloadStruct, - // OUT -) -> SpdmResult { - if vendor_defined_req_payload_struct.req_length < 1 - || vendor_defined_req_payload_struct.vendor_defined_req_payload[0] != TDISP_PROTOCOL_ID - { - Err(SPDM_STATUS_INVALID_PARAMETER) - } else { - spdm_requester - .send_spdm_vendor_defined_request( - Some(session_id), - STANDARD_ID, - vendor_id(), - vendor_defined_req_payload_struct, - ) - .await - } -} +// Copyright (c) 2023 Intel Corporation +// +// SPDX-License-Identifier: Apache-2.0 or MIT + +use async_or::{async_or, await_or}; +use spdmlib::error::SPDM_STATUS_INVALID_PARAMETER; +use spdmlib::message::VendorDefinedRspPayloadStruct; +use spdmlib::{ + error::SpdmResult, message::VendorDefinedReqPayloadStruct, requester::RequesterContext, +}; + +use crate::pci_tdisp::vendor_id; +use crate::pci_tdisp::STANDARD_ID; +use crate::pci_tdisp::TDISP_PROTOCOL_ID; + +#[async_or] +pub fn pci_tdisp_req_vdm_request( + // IN + spdm_requester: &mut RequesterContext, + session_id: u32, + vendor_defined_req_payload_struct: VendorDefinedReqPayloadStruct, + // OUT +) -> SpdmResult { + if vendor_defined_req_payload_struct.req_length < 1 + || vendor_defined_req_payload_struct.vendor_defined_req_payload[0] != TDISP_PROTOCOL_ID + { + Err(SPDM_STATUS_INVALID_PARAMETER) + } else { + await_or!(spdm_requester.send_spdm_vendor_defined_request( + Some(session_id), + STANDARD_ID, + vendor_id(), + vendor_defined_req_payload_struct, + )) + } +} diff --git a/test/spdm-emu/Cargo.toml b/test/spdm-emu/Cargo.toml index 640f2b7a..a3579e1a 100644 --- a/test/spdm-emu/Cargo.toml +++ b/test/spdm-emu/Cargo.toml @@ -22,6 +22,7 @@ async-recursion = "1.0.4" spin = { version = "0.9.8" } tokio = { version = "1.30.0", features = ["full"] } executor = { path = "../../executor" } +async-or = { path = "../../async-or" } spdmlib_crypto_mbedtls = { path = "../../spdmlib_crypto_mbedtls", default-features = false, optional = true } @@ -32,5 +33,10 @@ mandatory-mut-auth = ["mut-auth", "spdmlib/mandatory-mut-auth"] spdm-ring = ["spdmlib/spdm-ring", "spdmlib/std"] spdm-mbedtls = ["spdmlib_crypto_mbedtls"] hashed-transcript-data = ["spdmlib/hashed-transcript-data", "spdmlib_crypto_mbedtls?/hashed-transcript-data"] -async-executor = [] -async-tokio = [] + +# when adding any other async runtime, please include "async" feature! +async-executor = ["async"] +async-tokio = ["async"] + +# below async feature is for internal usage only +async = ["spdmlib/async", "async-or/async", "mctp_transport/async", "pcidoe_transport/async"] \ No newline at end of file diff --git a/test/spdm-emu/src/async_runtime.rs b/test/spdm-emu/src/async_runtime.rs index 45122706..fc37f751 100644 --- a/test/spdm-emu/src/async_runtime.rs +++ b/test/spdm-emu/src/async_runtime.rs @@ -7,6 +7,7 @@ use alloc::boxed::Box; use core::{future::Future, pin::Pin}; // Run async task +#[cfg(feature = "async")] pub fn block_on(future: Pin + 'static + Send>>) -> T where T: Send + 'static, diff --git a/test/spdm-emu/src/lib.rs b/test/spdm-emu/src/lib.rs index 1266d5e1..c285dec5 100644 --- a/test/spdm-emu/src/lib.rs +++ b/test/spdm-emu/src/lib.rs @@ -4,6 +4,7 @@ #![forbid(unsafe_code)] +#[cfg(feature = "async")] pub mod async_runtime; pub mod crypto; pub mod crypto_callback; diff --git a/test/spdm-emu/src/socket_io_transport.rs b/test/spdm-emu/src/socket_io_transport.rs index b9c0c5b2..534a1461 100644 --- a/test/spdm-emu/src/socket_io_transport.rs +++ b/test/spdm-emu/src/socket_io_transport.rs @@ -5,7 +5,7 @@ use crate::spdm_emu::*; use std::net::TcpStream; -use async_trait::async_trait; +use async_or::{async_impl_or, await_or}; use spdmlib::common::SpdmDeviceIo; use spdmlib::config; use spdmlib::error::SpdmResult; @@ -35,9 +35,10 @@ impl SocketIoTransport { } } -#[async_trait] +#[async_impl_or] impl SpdmDeviceIo for SocketIoTransport { - async fn receive( + #[async_or] + fn receive( &mut self, read_buffer: Arc>, timeout: usize, @@ -48,7 +49,7 @@ impl SpdmDeviceIo for SocketIoTransport { let read_buffer = read_buffer.deref_mut(); if let Some((_, command, payload)) = - receive_message(self.data.clone(), &mut buffer[..], timeout).await + await_or!(receive_message(self.data.clone(), &mut buffer[..], timeout)) { // TBD: do we need this? // self.transport_type = transport_type; @@ -68,7 +69,8 @@ impl SpdmDeviceIo for SocketIoTransport { } } - async fn send(&mut self, buffer: Arc<&[u8]>) -> SpdmResult { + #[async_or] + fn send(&mut self, buffer: Arc<&[u8]>) -> SpdmResult { send_message( self.data.clone(), self.transport_type, @@ -78,7 +80,8 @@ impl SpdmDeviceIo for SocketIoTransport { Ok(()) } - async fn flush_all(&mut self) -> SpdmResult { + #[async_or] + fn flush_all(&mut self) -> SpdmResult { Ok(()) } } diff --git a/test/spdm-emu/src/spdm_emu.rs b/test/spdm-emu/src/spdm_emu.rs index 93fdc7f1..9c8349f9 100644 --- a/test/spdm-emu/src/spdm_emu.rs +++ b/test/spdm-emu/src/spdm_emu.rs @@ -5,6 +5,7 @@ use std::io::{Read, Write}; use std::net::TcpStream; +use async_or::async_or; use spin::Mutex; extern crate alloc; use alloc::sync::Arc; @@ -55,7 +56,8 @@ impl Codec for SpdmSocketHeader { } // u32 type, u32 command, usize, payload -pub async fn receive_message( +#[async_or] +pub fn receive_message( stream: Arc>, buffer: &mut [u8], _timeout: usize, diff --git a/test/spdm-requester-emu/Cargo.toml b/test/spdm-requester-emu/Cargo.toml index 40e67fbb..5bd7ba66 100644 --- a/test/spdm-requester-emu/Cargo.toml +++ b/test/spdm-requester-emu/Cargo.toml @@ -20,6 +20,7 @@ futures = { version = "0.3", default-features = false } spin = { version = "0.9.8" } tokio = { version = "1.30.0", features = ["full"] } executor = { path = "../../executor" } +async-or = { path = "../../async-or" } [features] default = ["spdm-emu/default", "async-executor"] @@ -27,5 +28,6 @@ mut-auth = ["spdm-emu/mut-auth"] spdm-ring = ["spdm-emu/spdm-ring"] spdm-mbedtls = ["spdm-emu/spdm-mbedtls"] hashed-transcript-data = ["spdm-emu/hashed-transcript-data"] -async-executor = ["spdm-emu/async-executor"] -async-tokio = ["spdm-emu/async-tokio"] \ No newline at end of file +async-executor = ["spdm-emu/async-executor", "async"] +async-tokio = ["spdm-emu/async-tokio", "async"] +async = ["async-or/async", "spdm-emu/async", "spdmlib/async", "mctp_transport/async", "pcidoe_transport/async", "idekm/async", "tdisp/async"] \ No newline at end of file diff --git a/test/spdm-requester-emu/src/main.rs b/test/spdm-requester-emu/src/main.rs index b56e20bd..5725b15c 100644 --- a/test/spdm-requester-emu/src/main.rs +++ b/test/spdm-requester-emu/src/main.rs @@ -4,6 +4,8 @@ #![forbid(unsafe_code)] +use async_or::async_or; +use async_or::await_or; use codec::Codec; use common::SpdmDeviceIo; use common::SpdmTransportEncap; @@ -22,7 +24,6 @@ use log::LevelFilter; use log::*; use simple_logger::SimpleLogger; -use spdm_emu::async_runtime::block_on; use spdm_emu::crypto_callback::SECRET_ASYM_IMPL_INSTANCE; use spdm_emu::secret_impl_sample::SECRET_PSK_IMPL_INSTANCE; use spdm_emu::EMU_STACK_SIZE; @@ -62,7 +63,8 @@ extern crate alloc; use alloc::sync::Arc; use core::ops::DerefMut; -async fn send_receive_hello( +#[async_or] +fn send_receive_hello( stream: Arc>, transport_encap: Arc>, transport_type: u32, @@ -72,14 +74,12 @@ async fn send_receive_hello( let mut transport_encap = transport_encap.lock(); let transport_encap = transport_encap.deref_mut(); - let used = transport_encap - .encap( - Arc::new(b"Client Hello!\0"), - Arc::new(Mutex::new(&mut payload[..])), - false, - ) - .await - .unwrap(); + let used = await_or!(transport_encap.encap( + Arc::new(b"Client Hello!\0"), + Arc::new(Mutex::new(&mut payload[..])), + false, + )) + .unwrap(); let _buffer_size = spdm_emu::spdm_emu::send_message( stream.clone(), @@ -88,13 +88,16 @@ async fn send_receive_hello( &payload[0..used], ); let mut buffer = [0u8; config::RECEIVER_BUFFER_SIZE]; - let (_transport_type, _command, _payload) = - spdm_emu::spdm_emu::receive_message(stream, &mut buffer[..], ST1) - .await - .unwrap(); + let (_transport_type, _command, _payload) = await_or!(spdm_emu::spdm_emu::receive_message( + stream, + &mut buffer[..], + ST1 + )) + .unwrap(); } -async fn send_receive_stop( +#[async_or] +fn send_receive_stop( stream: Arc>, transport_encap: Arc>, transport_type: u32, @@ -106,10 +109,12 @@ async fn send_receive_stop( let mut transport_encap = transport_encap.lock(); let transport_encap = transport_encap.deref_mut(); - let used = transport_encap - .encap(Arc::new(b""), Arc::new(Mutex::new(&mut payload[..])), false) - .await - .unwrap(); + let used = await_or!(transport_encap.encap( + Arc::new(b""), + Arc::new(Mutex::new(&mut payload[..])), + false + )) + .unwrap(); let _buffer_size = spdm_emu::spdm_emu::send_message( stream.clone(), @@ -118,13 +123,16 @@ async fn send_receive_stop( &payload[0..used], ); let mut buffer = [0u8; config::RECEIVER_BUFFER_SIZE]; - let (_transport_type, _command, _payload) = - spdm_emu::spdm_emu::receive_message(stream, &mut buffer[..], ST1) - .await - .unwrap(); + let (_transport_type, _command, _payload) = await_or!(spdm_emu::spdm_emu::receive_message( + stream, + &mut buffer[..], + ST1 + )) + .unwrap(); } -async fn test_spdm( +#[async_or] +fn test_spdm( socket_io_transport: Arc>, transport_encap: Arc>, ) { @@ -255,29 +263,23 @@ async fn test_spdm( ); let mut transcript_vca = None; - if context.init_connection(&mut transcript_vca).await.is_err() { + if await_or!(context.init_connection(&mut transcript_vca)).is_err() { panic!("init_connection failed!"); } - if context.send_receive_spdm_digest(None).await.is_err() { + if await_or!(context.send_receive_spdm_digest(None)).is_err() { panic!("send_receive_spdm_digest failed!"); } - if context - .send_receive_spdm_certificate(None, 0) - .await - .is_err() - { + if await_or!(context.send_receive_spdm_certificate(None, 0)).is_err() { panic!("send_receive_spdm_certificate failed!"); } - if context - .send_receive_spdm_challenge( - 0, - SpdmMeasurementSummaryHashType::SpdmMeasurementSummaryHashTypeNone, - ) - .await - .is_err() + if await_or!(context.send_receive_spdm_challenge( + 0, + SpdmMeasurementSummaryHashType::SpdmMeasurementSummaryHashTypeNone, + )) + .is_err() { panic!("send_receive_spdm_challenge failed!"); } @@ -287,19 +289,17 @@ async fn test_spdm( let mut content_changed = None; let mut transcript_meas = None; - if context - .send_receive_spdm_measurement( - None, - 0, - SpdmMeasurementAttributes::SIGNATURE_REQUESTED, - SpdmMeasurementOperation::SpdmMeasurementRequestAll, - &mut content_changed, - &mut total_number, - &mut spdm_measurement_record_structure, - &mut transcript_meas, - ) - .await - .is_err() + if await_or!(context.send_receive_spdm_measurement( + None, + 0, + SpdmMeasurementAttributes::SIGNATURE_REQUESTED, + SpdmMeasurementOperation::SpdmMeasurementRequestAll, + &mut content_changed, + &mut total_number, + &mut spdm_measurement_record_structure, + &mut transcript_meas, + )) + .is_err() { panic!("send_receive_spdm_measurement failed!"); } @@ -308,13 +308,11 @@ async fn test_spdm( panic!("get message_m from send_receive_spdm_measurement failed!"); } - let result = context - .start_session( - false, - 0, - SpdmMeasurementSummaryHashType::SpdmMeasurementSummaryHashTypeNone, - ) - .await; + let result = await_or!(context.start_session( + false, + 0, + SpdmMeasurementSummaryHashType::SpdmMeasurementSummaryHashTypeNone, + )); if let Ok(session_id) = result { info!("\nSession established ... session_id {:0x?}\n", session_id); info!("Key Information ...\n"); @@ -338,18 +336,13 @@ async fn test_spdm( response_direction.salt.as_ref() ); - if context - .send_receive_spdm_heartbeat(session_id) - .await - .is_err() - { + if await_or!(context.send_receive_spdm_heartbeat(session_id)).is_err() { panic!("send_receive_spdm_heartbeat failed"); } - if context - .send_receive_spdm_key_update(session_id, SpdmKeyUpdateOperation::SpdmUpdateAllKeys) - .await - .is_err() + if await_or!(context + .send_receive_spdm_key_update(session_id, SpdmKeyUpdateOperation::SpdmUpdateAllKeys)) + .is_err() { panic!("send_receive_spdm_key_update failed"); } @@ -357,19 +350,17 @@ async fn test_spdm( let mut content_changed = None; let mut transcript_meas = None; - if context - .send_receive_spdm_measurement( - Some(session_id), - 0, - SpdmMeasurementAttributes::SIGNATURE_REQUESTED, - SpdmMeasurementOperation::SpdmMeasurementQueryTotalNumber, - &mut content_changed, - &mut total_number, - &mut spdm_measurement_record_structure, - &mut transcript_meas, - ) - .await - .is_err() + if await_or!(context.send_receive_spdm_measurement( + Some(session_id), + 0, + SpdmMeasurementAttributes::SIGNATURE_REQUESTED, + SpdmMeasurementOperation::SpdmMeasurementQueryTotalNumber, + &mut content_changed, + &mut total_number, + &mut spdm_measurement_record_structure, + &mut transcript_meas, + )) + .is_err() { panic!("send_receive_spdm_measurement failed"); } @@ -378,38 +369,28 @@ async fn test_spdm( panic!("get VCA + message_m from send_receive_spdm_measurement failed!"); } - if context - .send_receive_spdm_digest(Some(session_id)) - .await - .is_err() - { + if await_or!(context.send_receive_spdm_digest(Some(session_id))).is_err() { panic!("send_receive_spdm_digest failed"); } - if context - .send_receive_spdm_certificate(Some(session_id), 0) - .await - .is_err() - { + if await_or!(context.send_receive_spdm_certificate(Some(session_id), 0)).is_err() { panic!("send_receive_spdm_certificate failed"); } - if context.end_session(session_id).await.is_err() { + if await_or!(context.end_session(session_id)).is_err() { panic!("end_session failed"); } } else { panic!("\nSession session_id not got\n"); } - let result = context - .start_session( - true, - 0, - SpdmMeasurementSummaryHashType::SpdmMeasurementSummaryHashTypeNone, - ) - .await; + let result = await_or!(context.start_session( + true, + 0, + SpdmMeasurementSummaryHashType::SpdmMeasurementSummaryHashTypeNone, + )); if let Ok(session_id) = result { - if context.end_session(session_id).await.is_err() { + if await_or!(context.end_session(session_id)).is_err() { panic!("\nSession session_id is err\n"); } } else { @@ -417,7 +398,8 @@ async fn test_spdm( } } -async fn test_idekm_tdisp( +#[async_or] +fn test_idekm_tdisp( socket_io_transport: Arc>, transport_encap: Arc>, ) { @@ -548,29 +530,23 @@ async fn test_idekm_tdisp( ); let mut transcript_vca = None; - if context.init_connection(&mut transcript_vca).await.is_err() { + if await_or!(context.init_connection(&mut transcript_vca)).is_err() { panic!("init_connection failed!"); } - if context.send_receive_spdm_digest(None).await.is_err() { + if await_or!(context.send_receive_spdm_digest(None)).is_err() { panic!("send_receive_spdm_digest failed!"); } - if context - .send_receive_spdm_certificate(None, 0) - .await - .is_err() - { + if await_or!(context.send_receive_spdm_certificate(None, 0)).is_err() { panic!("send_receive_spdm_certificate failed!"); } - if context - .send_receive_spdm_challenge( - 0, - SpdmMeasurementSummaryHashType::SpdmMeasurementSummaryHashTypeNone, - ) - .await - .is_err() + if await_or!(context.send_receive_spdm_challenge( + 0, + SpdmMeasurementSummaryHashType::SpdmMeasurementSummaryHashTypeNone, + )) + .is_err() { panic!("send_receive_spdm_challenge failed!"); } @@ -580,31 +556,27 @@ async fn test_idekm_tdisp( let mut content_changed = None; let mut transcript_meas = None; - if context - .send_receive_spdm_measurement( - None, - 0, - SpdmMeasurementAttributes::SIGNATURE_REQUESTED, - SpdmMeasurementOperation::SpdmMeasurementRequestAll, - &mut content_changed, - &mut total_number, - &mut spdm_measurement_record_structure, - &mut transcript_meas, - ) - .await - .is_err() + if await_or!(context.send_receive_spdm_measurement( + None, + 0, + SpdmMeasurementAttributes::SIGNATURE_REQUESTED, + SpdmMeasurementOperation::SpdmMeasurementRequestAll, + &mut content_changed, + &mut total_number, + &mut spdm_measurement_record_structure, + &mut transcript_meas, + )) + .is_err() { panic!("send_receive_spdm_measurement failed!"); } - let session_id = context - .start_session( - false, - 0, - SpdmMeasurementSummaryHashType::SpdmMeasurementSummaryHashTypeNone, - ) - .await - .unwrap(); + let session_id = await_or!(context.start_session( + false, + 0, + SpdmMeasurementSummaryHashType::SpdmMeasurementSummaryHashTypeNone, + )) + .unwrap(); // ide_km test let mut idekm_req_context = IdekmReqContext; @@ -616,20 +588,18 @@ async fn test_idekm_tdisp( let mut max_port_index = 0u8; let mut ide_reg_block = [0u32; PCI_IDE_KM_IDE_REG_BLOCK_MAX_COUNT]; let mut ide_reg_block_cnt = 0usize; - idekm_req_context - .pci_ide_km_query( - &mut context, - session_id, - port_index, - &mut dev_func_num, - &mut bus_num, - &mut segment, - &mut max_port_index, - &mut ide_reg_block, - &mut ide_reg_block_cnt, - ) - .await - .unwrap(); + await_or!(idekm_req_context.pci_ide_km_query( + &mut context, + session_id, + port_index, + &mut dev_func_num, + &mut bus_num, + &mut segment, + &mut max_port_index, + &mut ide_reg_block, + &mut ide_reg_block_cnt, + )) + .unwrap(); // ide_km key_prog key set 0 | RX | PR let stream_id = 0u8; @@ -649,20 +619,18 @@ async fn test_idekm_tdisp( key_iv.iv[0] = 0; key_iv.iv[1] = 1; let mut kp_ack_status = KpAckStatus::default(); - idekm_req_context - .pci_ide_km_key_prog( - &mut context, - session_id, - stream_id, - key_set, - key_direction, - key_sub_stream, - port_index, - key_iv, - &mut kp_ack_status, - ) - .await - .unwrap(); + await_or!(idekm_req_context.pci_ide_km_key_prog( + &mut context, + session_id, + stream_id, + key_set, + key_direction, + key_sub_stream, + port_index, + key_iv, + &mut kp_ack_status, + )) + .unwrap(); if kp_ack_status != KpAckStatus::SUCCESS { panic!( "KEY_PROG at Key Set 0 | RX | PR failed with {:X?}", @@ -690,20 +658,18 @@ async fn test_idekm_tdisp( key_iv.iv[0] = 0; key_iv.iv[1] = 1; let mut kp_ack_status = KpAckStatus::default(); - idekm_req_context - .pci_ide_km_key_prog( - &mut context, - session_id, - stream_id, - key_set, - key_direction, - key_sub_stream, - port_index, - key_iv, - &mut kp_ack_status, - ) - .await - .unwrap(); + await_or!(idekm_req_context.pci_ide_km_key_prog( + &mut context, + session_id, + stream_id, + key_set, + key_direction, + key_sub_stream, + port_index, + key_iv, + &mut kp_ack_status, + )) + .unwrap(); if kp_ack_status != KpAckStatus::SUCCESS { panic!( "KEY_PROG at Key Set 0 | RX | NPR failed with {:X?}", @@ -731,20 +697,18 @@ async fn test_idekm_tdisp( key_iv.iv[0] = 0; key_iv.iv[1] = 1; let mut kp_ack_status = KpAckStatus::default(); - idekm_req_context - .pci_ide_km_key_prog( - &mut context, - session_id, - stream_id, - key_set, - key_direction, - key_sub_stream, - port_index, - key_iv, - &mut kp_ack_status, - ) - .await - .unwrap(); + await_or!(idekm_req_context.pci_ide_km_key_prog( + &mut context, + session_id, + stream_id, + key_set, + key_direction, + key_sub_stream, + port_index, + key_iv, + &mut kp_ack_status, + )) + .unwrap(); if kp_ack_status != KpAckStatus::SUCCESS { panic!( "KEY_PROG at Key Set 0 | RX | CPL failed with {:X?}", @@ -772,20 +736,18 @@ async fn test_idekm_tdisp( key_iv.iv[0] = 0; key_iv.iv[1] = 1; let mut kp_ack_status = KpAckStatus::default(); - idekm_req_context - .pci_ide_km_key_prog( - &mut context, - session_id, - stream_id, - key_set, - key_direction, - key_sub_stream, - port_index, - key_iv, - &mut kp_ack_status, - ) - .await - .unwrap(); + await_or!(idekm_req_context.pci_ide_km_key_prog( + &mut context, + session_id, + stream_id, + key_set, + key_direction, + key_sub_stream, + port_index, + key_iv, + &mut kp_ack_status, + )) + .unwrap(); if kp_ack_status != KpAckStatus::SUCCESS { panic!( "KEY_PROG at Key Set 0 | TX | PR failed with {:X?}", @@ -813,20 +775,18 @@ async fn test_idekm_tdisp( key_iv.iv[0] = 0; key_iv.iv[1] = 1; let mut kp_ack_status = KpAckStatus::default(); - idekm_req_context - .pci_ide_km_key_prog( - &mut context, - session_id, - stream_id, - key_set, - key_direction, - key_sub_stream, - port_index, - key_iv, - &mut kp_ack_status, - ) - .await - .unwrap(); + await_or!(idekm_req_context.pci_ide_km_key_prog( + &mut context, + session_id, + stream_id, + key_set, + key_direction, + key_sub_stream, + port_index, + key_iv, + &mut kp_ack_status, + )) + .unwrap(); if kp_ack_status != KpAckStatus::SUCCESS { panic!( "KEY_PROG at Key Set 0 | TX | NPR failed with {:X?}", @@ -854,20 +814,18 @@ async fn test_idekm_tdisp( key_iv.iv[0] = 0; key_iv.iv[1] = 1; let mut kp_ack_status = KpAckStatus::default(); - idekm_req_context - .pci_ide_km_key_prog( - &mut context, - session_id, - stream_id, - key_set, - key_direction, - key_sub_stream, - port_index, - key_iv, - &mut kp_ack_status, - ) - .await - .unwrap(); + await_or!(idekm_req_context.pci_ide_km_key_prog( + &mut context, + session_id, + stream_id, + key_set, + key_direction, + key_sub_stream, + port_index, + key_iv, + &mut kp_ack_status, + )) + .unwrap(); if kp_ack_status != KpAckStatus::SUCCESS { panic!( "KEY_PROG at Key Set 0 | TX | CPL failed with {:X?}", @@ -883,216 +841,192 @@ async fn test_idekm_tdisp( let key_direction = KEY_DIRECTION_RX; let key_sub_stream = KEY_SUB_STREAM_PR; let port_index = 0u8; - idekm_req_context - .pci_ide_km_key_set_go( - &mut context, - session_id, - stream_id, - key_set, - key_direction, - key_sub_stream, - port_index, - ) - .await - .unwrap(); + await_or!(idekm_req_context.pci_ide_km_key_set_go( + &mut context, + session_id, + stream_id, + key_set, + key_direction, + key_sub_stream, + port_index, + )) + .unwrap(); println!("Successful KEY_SET_GO at Key Set 0 | RX | PR!"); // ide_km key_set_go key set 0 | RX | NPR let key_set = KEY_SET_0; let key_direction = KEY_DIRECTION_RX; let key_sub_stream = KEY_SUB_STREAM_NPR; - idekm_req_context - .pci_ide_km_key_set_go( - &mut context, - session_id, - stream_id, - key_set, - key_direction, - key_sub_stream, - port_index, - ) - .await - .unwrap(); + await_or!(idekm_req_context.pci_ide_km_key_set_go( + &mut context, + session_id, + stream_id, + key_set, + key_direction, + key_sub_stream, + port_index, + )) + .unwrap(); println!("Successful KEY_SET_GO at Key Set 0 | RX | NPR!"); // ide_km key_set_go key set 0 | RX | CPL let key_set = KEY_SET_0; let key_direction = KEY_DIRECTION_RX; let key_sub_stream = KEY_SUB_STREAM_CPL; - idekm_req_context - .pci_ide_km_key_set_go( - &mut context, - session_id, - stream_id, - key_set, - key_direction, - key_sub_stream, - port_index, - ) - .await - .unwrap(); + await_or!(idekm_req_context.pci_ide_km_key_set_go( + &mut context, + session_id, + stream_id, + key_set, + key_direction, + key_sub_stream, + port_index, + )) + .unwrap(); println!("Successful KEY_SET_GO at Key Set 0 | RX | CPL!"); // ide_km key_set_go key set 0 | TX | PR let key_set = KEY_SET_0; let key_direction = KEY_DIRECTION_TX; let key_sub_stream = KEY_SUB_STREAM_PR; - idekm_req_context - .pci_ide_km_key_set_go( - &mut context, - session_id, - stream_id, - key_set, - key_direction, - key_sub_stream, - port_index, - ) - .await - .unwrap(); + await_or!(idekm_req_context.pci_ide_km_key_set_go( + &mut context, + session_id, + stream_id, + key_set, + key_direction, + key_sub_stream, + port_index, + )) + .unwrap(); println!("Successful KEY_SET_GO at Key Set 0 | TX | PR!"); // ide_km key_set_go key set 0 | TX | NPR let key_set = KEY_SET_0; let key_direction = KEY_DIRECTION_TX; let key_sub_stream = KEY_SUB_STREAM_NPR; - idekm_req_context - .pci_ide_km_key_set_go( - &mut context, - session_id, - stream_id, - key_set, - key_direction, - key_sub_stream, - port_index, - ) - .await - .unwrap(); + await_or!(idekm_req_context.pci_ide_km_key_set_go( + &mut context, + session_id, + stream_id, + key_set, + key_direction, + key_sub_stream, + port_index, + )) + .unwrap(); println!("Successful KEY_SET_GO at Key Set 0 | TX | NPR!"); // ide_km key_set_go key set 0 | TX | CPL let key_set = KEY_SET_0; let key_direction = KEY_DIRECTION_TX; let key_sub_stream = KEY_SUB_STREAM_CPL; - idekm_req_context - .pci_ide_km_key_set_go( - &mut context, - session_id, - stream_id, - key_set, - key_direction, - key_sub_stream, - port_index, - ) - .await - .unwrap(); + await_or!(idekm_req_context.pci_ide_km_key_set_go( + &mut context, + session_id, + stream_id, + key_set, + key_direction, + key_sub_stream, + port_index, + )) + .unwrap(); println!("Successful KEY_SET_GO at Key Set 0 | TX | CPL!"); // ide_km key_set_stop key set 0 | RX | PR let key_set = KEY_SET_0; let key_direction = KEY_DIRECTION_RX; let key_sub_stream = KEY_SUB_STREAM_PR; - idekm_req_context - .pci_ide_km_key_set_stop( - &mut context, - session_id, - stream_id, - key_set, - key_direction, - key_sub_stream, - port_index, - ) - .await - .unwrap(); + await_or!(idekm_req_context.pci_ide_km_key_set_stop( + &mut context, + session_id, + stream_id, + key_set, + key_direction, + key_sub_stream, + port_index, + )) + .unwrap(); println!("Successful KEY_SET_STOP at Key Set 0 | RX | PR!"); // ide_km key_set_stop key set 0 | RX | NPR let key_set = KEY_SET_0; let key_direction = KEY_DIRECTION_RX; let key_sub_stream = KEY_SUB_STREAM_NPR; - idekm_req_context - .pci_ide_km_key_set_stop( - &mut context, - session_id, - stream_id, - key_set, - key_direction, - key_sub_stream, - port_index, - ) - .await - .unwrap(); + await_or!(idekm_req_context.pci_ide_km_key_set_stop( + &mut context, + session_id, + stream_id, + key_set, + key_direction, + key_sub_stream, + port_index, + )) + .unwrap(); println!("Successful KEY_SET_STOP at Key Set 0 | RX | NPR!"); // ide_km key_set_stop key set 0 | RX | CPL let key_set = KEY_SET_0; let key_direction = KEY_DIRECTION_RX; let key_sub_stream = KEY_SUB_STREAM_CPL; - idekm_req_context - .pci_ide_km_key_set_stop( - &mut context, - session_id, - stream_id, - key_set, - key_direction, - key_sub_stream, - port_index, - ) - .await - .unwrap(); + await_or!(idekm_req_context.pci_ide_km_key_set_stop( + &mut context, + session_id, + stream_id, + key_set, + key_direction, + key_sub_stream, + port_index, + )) + .unwrap(); println!("Successful KEY_SET_STOP at Key Set 0 | RX | CPL!"); // ide_km key_set_stop key set 0 | TX | PR let key_set = KEY_SET_0; let key_direction = KEY_DIRECTION_TX; let key_sub_stream = KEY_SUB_STREAM_PR; - idekm_req_context - .pci_ide_km_key_set_stop( - &mut context, - session_id, - stream_id, - key_set, - key_direction, - key_sub_stream, - port_index, - ) - .await - .unwrap(); + await_or!(idekm_req_context.pci_ide_km_key_set_stop( + &mut context, + session_id, + stream_id, + key_set, + key_direction, + key_sub_stream, + port_index, + )) + .unwrap(); println!("Successful KEY_SET_STOP at Key Set 0 | TX | PR!"); // ide_km key_set_stop key set 0 | TX | NPR let key_set = KEY_SET_0; let key_direction = KEY_DIRECTION_TX; let key_sub_stream = KEY_SUB_STREAM_NPR; - idekm_req_context - .pci_ide_km_key_set_stop( - &mut context, - session_id, - stream_id, - key_set, - key_direction, - key_sub_stream, - port_index, - ) - .await - .unwrap(); + await_or!(idekm_req_context.pci_ide_km_key_set_stop( + &mut context, + session_id, + stream_id, + key_set, + key_direction, + key_sub_stream, + port_index, + )) + .unwrap(); println!("Successful KEY_SET_STOP at Key Set 0 | TX | NPR!"); // ide_km key_set_stop key set 0 | TX | CPL let key_set = KEY_SET_0; let key_direction = KEY_DIRECTION_TX; let key_sub_stream = KEY_SUB_STREAM_CPL; - idekm_req_context - .pci_ide_km_key_set_stop( - &mut context, - session_id, - stream_id, - key_set, - key_direction, - key_sub_stream, - port_index, - ) - .await - .unwrap(); + await_or!(idekm_req_context.pci_ide_km_key_set_stop( + &mut context, + session_id, + stream_id, + key_set, + key_direction, + key_sub_stream, + port_index, + )) + .unwrap(); println!("Successful KEY_SET_STOP at Key Set 0 | TX | CPL!"); // tdisp test @@ -1104,9 +1038,12 @@ async fn test_idekm_tdisp( }, }; - pci_tdisp_req_get_tdisp_version(&mut context, session_id, interface_id) - .await - .unwrap(); + await_or!(pci_tdisp_req_get_tdisp_version( + &mut context, + session_id, + interface_id + )) + .unwrap(); println!("Successful Get Tdisp Version!"); let tsm_caps = 0; @@ -1116,7 +1053,7 @@ async fn test_idekm_tdisp( let mut num_req_this = 0u8; let mut num_req_all = 0u8; let mut req_msgs_supported = [0u8; 16]; - pci_tdisp_req_get_tdisp_capabilities( + await_or!(pci_tdisp_req_get_tdisp_capabilities( &mut context, session_id, tsm_caps, @@ -1127,19 +1064,17 @@ async fn test_idekm_tdisp( &mut num_req_this, &mut num_req_all, &mut req_msgs_supported, - ) - .await + )) .unwrap(); println!("Successful Get Tdisp Capabilities!"); let mut tdi_state = TdiState::ERROR; - pci_tdisp_req_get_device_interface_state( + await_or!(pci_tdisp_req_get_device_interface_state( &mut context, session_id, interface_id, &mut tdi_state, - ) - .await + )) .unwrap(); assert_eq!(tdi_state, TdiState::CONFIG_UNLOCKED); println!("Successful Get Tdisp State: {:X?}!", tdi_state); @@ -1150,7 +1085,7 @@ async fn test_idekm_tdisp( let bind_p2p_address_mask = 0; let mut start_interface_nonce = [0u8; START_INTERFACE_NONCE_LEN]; let mut tdisp_error_code = None; - pci_tdisp_req_lock_interface_request( + await_or!(pci_tdisp_req_lock_interface_request( &mut context, session_id, interface_id, @@ -1160,8 +1095,7 @@ async fn test_idekm_tdisp( bind_p2p_address_mask, &mut start_interface_nonce, &mut tdisp_error_code, - ) - .await + )) .unwrap(); assert!(tdisp_error_code.is_none()); println!( @@ -1169,28 +1103,26 @@ async fn test_idekm_tdisp( start_interface_nonce ); - pci_tdisp_req_get_device_interface_state( + await_or!(pci_tdisp_req_get_device_interface_state( &mut context, session_id, interface_id, &mut tdi_state, - ) - .await + )) .unwrap(); assert_eq!(tdi_state, TdiState::CONFIG_LOCKED); println!("Successful Get Tdisp State: {:X?}!", tdi_state); let mut report = [0u8; MAX_DEVICE_REPORT_BUFFER]; let mut report_size = 0usize; - pci_tdisp_req_get_device_interface_report( + await_or!(pci_tdisp_req_get_device_interface_report( &mut context, session_id, interface_id, &mut report, &mut report_size, &mut tdisp_error_code, - ) - .await + )) .unwrap(); assert!(tdisp_error_code.is_none()); let tdi_report = TdiReportStructure::read_bytes(&report).unwrap(); @@ -1199,47 +1131,47 @@ async fn test_idekm_tdisp( tdi_report ); - pci_tdisp_req_start_interface_request( + await_or!(pci_tdisp_req_start_interface_request( &mut context, session_id, interface_id, &start_interface_nonce, &mut tdisp_error_code, - ) - .await + )) .unwrap(); assert!(tdisp_error_code.is_none()); println!("Successful Start Interface!"); - pci_tdisp_req_get_device_interface_state( + await_or!(pci_tdisp_req_get_device_interface_state( &mut context, session_id, interface_id, &mut tdi_state, - ) - .await + )) .unwrap(); assert_eq!(tdi_state, TdiState::RUN); println!("Successful Get Tdisp State: {:X?}!", tdi_state); - pci_tdisp_req_stop_interface_request(&mut context, session_id, interface_id) - .await - .unwrap(); + await_or!(pci_tdisp_req_stop_interface_request( + &mut context, + session_id, + interface_id + )) + .unwrap(); println!("Successful Stop Interface!"); - pci_tdisp_req_get_device_interface_state( + await_or!(pci_tdisp_req_get_device_interface_state( &mut context, session_id, interface_id, &mut tdi_state, - ) - .await + )) .unwrap(); assert_eq!(tdi_state, TdiState::CONFIG_UNLOCKED); println!("Successful Get Tdisp State: {:X?}!", tdi_state); // end spdm session - context.end_session(session_id).await.unwrap(); + await_or!(context.end_session(session_id)).unwrap(); } // A new logger enables the user to choose log level by setting a `SPDM_LOG` environment variable. @@ -1293,31 +1225,49 @@ fn emu_main() { SOCKET_TRANSPORT_TYPE_MCTP }; - block_on(Box::pin(send_receive_hello( - socket.clone(), - transport_encap.clone(), - transport_type, - ))); + #[cfg(feature = "async")] + { + spdm_emu::async_runtime::block_on(Box::pin(send_receive_hello( + socket.clone(), + transport_encap.clone(), + transport_type, + ))); + + let socket_io_transport = SocketIoTransport::new(socket.clone()); + let socket_io_transport: Arc> = + Arc::new(Mutex::new(socket_io_transport)); + + spdm_emu::async_runtime::block_on(Box::pin(test_spdm( + socket_io_transport.clone(), + transport_encap.clone(), + ))); + + spdm_emu::async_runtime::block_on(Box::pin(test_idekm_tdisp( + socket_io_transport.clone(), + transport_encap.clone(), + ))); + + spdm_emu::async_runtime::block_on(Box::pin(send_receive_stop( + socket, + transport_encap, + transport_type, + ))); + } + + #[cfg(not(feature = "async"))] + { + send_receive_hello(socket.clone(), transport_encap.clone(), transport_type); - let socket_io_transport = SocketIoTransport::new(socket.clone()); - let socket_io_transport: Arc> = - Arc::new(Mutex::new(socket_io_transport)); + let socket_io_transport = SocketIoTransport::new(socket.clone()); + let socket_io_transport: Arc> = + Arc::new(Mutex::new(socket_io_transport)); - block_on(Box::pin(test_spdm( - socket_io_transport.clone(), - transport_encap.clone(), - ))); + test_spdm(socket_io_transport.clone(), transport_encap.clone()); - block_on(Box::pin(test_idekm_tdisp( - socket_io_transport.clone(), - transport_encap.clone(), - ))); + test_idekm_tdisp(socket_io_transport.clone(), transport_encap.clone()); - block_on(Box::pin(send_receive_stop( - socket, - transport_encap, - transport_type, - ))); + send_receive_stop(socket, transport_encap, transport_type); + } } fn main() { diff --git a/test/spdm-responder-emu/Cargo.toml b/test/spdm-responder-emu/Cargo.toml index a77c7840..fffaf484 100644 --- a/test/spdm-responder-emu/Cargo.toml +++ b/test/spdm-responder-emu/Cargo.toml @@ -20,13 +20,16 @@ futures = { version = "0.3", default-features = false } spin = { version = "0.9.8" } tokio = { version = "1.30.0", features = ["full"] } executor = { path = "../../executor" } +async-or = { path = "../../async-or" } zeroize = { version = "1.5.0", features = ["zeroize_derive"]} [features] +default = ["spdm-emu/default", "async-executor"] mut-auth = ["spdm-emu/mut-auth", "async-executor"] mandatory-mut-auth = ["mut-auth", "spdm-emu/mandatory-mut-auth"] spdm-ring = ["spdm-emu/spdm-ring"] spdm-mbedtls = ["spdm-emu/spdm-mbedtls"] hashed-transcript-data = ["spdm-emu/hashed-transcript-data"] -async-executor = ["spdm-emu/async-executor"] -async-tokio = ["spdm-emu/async-tokio"] +async-executor = ["spdm-emu/async-executor", "async"] +async-tokio = ["spdm-emu/async-tokio", "async"] +async = ["async-or/async", "spdm-emu/async", "spdmlib/async", "mctp_transport/async", "pcidoe_transport/async", "idekm/async", "tdisp/async"] \ No newline at end of file diff --git a/test/spdm-responder-emu/src/main.rs b/test/spdm-responder-emu/src/main.rs index 47a6bf24..93ccbe4d 100644 --- a/test/spdm-responder-emu/src/main.rs +++ b/test/spdm-responder-emu/src/main.rs @@ -1,8 +1,10 @@ // Copyright (c) 2020 Intel Corporation // // SPDX-License-Identifier: Apache-2.0 or MIT +#![feature(stmt_expr_attributes)] mod spdm_device_idekm_example; +use async_or::{async_or, await_or}; use idekm::pci_ide_km_responder::pci_ide_km_rsp_dispatcher; use idekm::pci_idekm::{vendor_id, IDEKM_PROTOCOL_ID}; use spdm_device_idekm_example::init_device_idekm_instance; @@ -12,7 +14,6 @@ use spdm_device_tdisp_example::init_device_tdisp_instance; use log::LevelFilter; use simple_logger::SimpleLogger; -use spdm_emu::async_runtime::block_on; use spdm_emu::watchdog_impl_sample::init_watchdog; use spdmlib::common::{SecuredMessageVersion, SpdmOpaqueSupport}; use spdmlib::config::{MAX_ROOT_CERT_SUPPORT, RECEIVER_BUFFER_SIZE}; @@ -52,7 +53,8 @@ use std::ops::Deref; use crate::spdm_device_tdisp_example::DeviceContext; -async fn process_socket_message( +#[async_or] +fn process_socket_message( stream: Arc>, transport_encap: Arc>, buffer: Arc>, @@ -74,25 +76,24 @@ async fn process_socket_message( match socket_header.command.to_be() { SOCKET_SPDM_COMMAND_TEST => { - send_hello(stream.clone(), transport_encap.clone(), res.0).await; + await_or!(send_hello(stream.clone(), transport_encap.clone(), res.0)); true } SOCKET_SPDM_COMMAND_STOP => { - send_stop(stream.clone(), transport_encap.clone(), res.0).await; + await_or!(send_stop(stream.clone(), transport_encap.clone(), res.0)); false } SOCKET_SPDM_COMMAND_NORMAL => true, _ => { if USE_PCIDOE { - send_pci_discovery( + await_or!(send_pci_discovery( stream.clone(), transport_encap.clone(), res.0, &buffer_ref[..buffer_size], - ) - .await + )) } else { - send_unknown(stream, transport_encap, res.0).await; + await_or!(send_unknown(stream, transport_encap, res.0)); false } } @@ -171,27 +172,56 @@ fn emu_main() { let mut need_continue; let raw_packet = [0u8; RECEIVER_BUFFER_SIZE]; let raw_packet = Arc::new(Mutex::new(raw_packet)); + loop { - let sz = block_on(Box::pin(handle_message( - stream.clone(), - if USE_PCIDOE { - pcidoe_transport_encap.clone() - } else { - mctp_transport_encap.clone() - }, - raw_packet.clone(), - ))); - - need_continue = block_on(Box::pin(process_socket_message( - stream.clone(), - if USE_PCIDOE { - pcidoe_transport_encap.clone() - } else { - mctp_transport_encap.clone() - }, - raw_packet.clone(), - sz, - ))); + #[cfg(feature = "async")] + { + let sz = spdm_emu::async_runtime::block_on(Box::pin(handle_message( + stream.clone(), + if USE_PCIDOE { + pcidoe_transport_encap.clone() + } else { + mctp_transport_encap.clone() + }, + raw_packet.clone(), + ))); + + need_continue = + spdm_emu::async_runtime::block_on(Box::pin(process_socket_message( + stream.clone(), + if USE_PCIDOE { + pcidoe_transport_encap.clone() + } else { + mctp_transport_encap.clone() + }, + raw_packet.clone(), + sz, + ))); + } + + #[cfg(not(feature = "async"))] + { + let sz = handle_message( + stream.clone(), + if USE_PCIDOE { + pcidoe_transport_encap.clone() + } else { + mctp_transport_encap.clone() + }, + raw_packet.clone(), + ); + + need_continue = process_socket_message( + stream.clone(), + if USE_PCIDOE { + pcidoe_transport_encap.clone() + } else { + mctp_transport_encap.clone() + }, + raw_packet.clone(), + sz, + ); + } if !need_continue { // TBD: return or break?? @@ -201,7 +231,8 @@ fn emu_main() { } } -async fn handle_message( +#[async_or] +fn handle_message( stream: Arc>, transport_encap: Arc>, raw_packet: Arc>, @@ -329,7 +360,7 @@ async fn handle_message( let mut raw_packet = raw_packet.lock(); let raw_packet = raw_packet.deref_mut(); raw_packet.zeroize(); - let res = context.process_message(false, 0, raw_packet).await; + let res = await_or!(context.process_message(false, 0, raw_packet)); match res { Ok(spdm_result) => match spdm_result { Ok(_) => continue, @@ -342,7 +373,8 @@ async fn handle_message( } } -pub async fn send_hello( +#[async_or] +pub fn send_hello( stream: Arc>, transport_encap: Arc>, tranport_type: u32, @@ -354,14 +386,12 @@ pub async fn send_hello( let mut transport_encap = transport_encap.lock(); let transport_encap = transport_encap.deref_mut(); - let used = transport_encap - .encap( - Arc::new(b"Server Hello!\0"), - Arc::new(Mutex::new(&mut payload[..])), - false, - ) - .await - .unwrap(); + let used = await_or!(transport_encap.encap( + Arc::new(b"Server Hello!\0"), + Arc::new(Mutex::new(&mut payload[..])), + false, + )) + .unwrap(); let _buffer_size = spdm_emu::spdm_emu::send_message( stream, @@ -371,7 +401,8 @@ pub async fn send_hello( ); } -pub async fn send_unknown( +#[async_or] +pub fn send_unknown( stream: Arc>, transport_encap: Arc>, transport_type: u32, @@ -381,10 +412,12 @@ pub async fn send_unknown( let mut payload = [0u8; 1024]; let mut transport_encap = transport_encap.lock(); let transport_encap = transport_encap.deref_mut(); - let used = transport_encap - .encap(Arc::new(b""), Arc::new(Mutex::new(&mut payload[..])), false) - .await - .unwrap(); + let used = await_or!(transport_encap.encap( + Arc::new(b""), + Arc::new(Mutex::new(&mut payload[..])), + false + )) + .unwrap(); let _buffer_size = spdm_emu::spdm_emu::send_message( stream, @@ -394,7 +427,8 @@ pub async fn send_unknown( ); } -pub async fn send_stop( +#[async_or] +pub fn send_stop( stream: Arc>, _transport_encap: Arc>, transport_type: u32, @@ -409,7 +443,8 @@ pub async fn send_stop( ); } -pub async fn send_pci_discovery( +#[async_or] +pub fn send_pci_discovery( stream: Arc>, transport_encap: Arc>, transport_type: u32, @@ -452,7 +487,11 @@ pub async fn send_pci_discovery( }, } if unknown_message { - send_unknown(stream.clone(), transport_encap, transport_type).await; + await_or!(send_unknown( + stream.clone(), + transport_encap, + transport_type + )); return false; } diff --git a/test/spdmlib-test/Cargo.toml b/test/spdmlib-test/Cargo.toml index c37db38f..5a0a84ce 100644 --- a/test/spdmlib-test/Cargo.toml +++ b/test/spdmlib-test/Cargo.toml @@ -6,7 +6,7 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -spdmlib = { path = "../../spdmlib", default-features = false, features=["spdm-ring"] } +spdmlib = { path = "../../spdmlib", default-features = false, features=["spdm-ring", "async"] } codec = { path = "../../codec", features = ["alloc"] } log = "0.4.13" ring = { version = "0.17.6" } @@ -16,7 +16,7 @@ async-trait = "0.1.71" async-recursion = "1.0.4" spin = { version = "0.9.8" } executor = { path = "../../executor" } -pcidoe_transport = { path = "../../pcidoe_transport" } +pcidoe_transport = { path = "../../pcidoe_transport", features=["async"] } [dev-dependencies] env_logger = "*"