diff --git a/spdmlib/src/common/key_schedule.rs b/spdmlib/src/common/key_schedule.rs index 16b26d1a..dc955c62 100644 --- a/spdmlib/src/common/key_schedule.rs +++ b/spdmlib/src/common/key_schedule.rs @@ -433,8 +433,8 @@ impl SpdmKeySchedule { let mut version = [0u8; 8]; version.copy_from_slice(SPDM_VERSION_VALUE); - version[SPDM_VERSION_VALUE_MAJOR_INDEX] = (spdm_version.get_u8() >> 4) + b'0'; - version[SPDM_VERSION_VALUE_MINOR_INDEX] = (spdm_version.get_u8() & 0x0F) + b'0'; + version[SPDM_VERSION_VALUE_MAJOR_INDEX] = (u8::from(spdm_version) >> 4) + b'0'; + version[SPDM_VERSION_VALUE_MINOR_INDEX] = (u8::from(spdm_version) & 0x0F) + b'0'; let mut writer = Writer::init(buffer); length.encode(&mut writer).ok()?; diff --git a/spdmlib/src/common/mod.rs b/spdmlib/src/common/mod.rs index 0fe5bc29..e9f6283a 100644 --- a/spdmlib/src/common/mod.rs +++ b/spdmlib/src/common/mod.rs @@ -422,7 +422,7 @@ impl SpdmContext { return Err(SPDM_STATUS_CRYPTO_ERROR); } - if spdm_version_sel.get_u8() >= SpdmVersion::SpdmVersion12.get_u8() { + if spdm_version_sel >= SpdmVersion::SpdmVersion12 { crypto::hash::hash_ctx_update( session.runtime_info.digest_context_l1l2.as_mut().unwrap(), message_a.as_ref(), @@ -443,9 +443,7 @@ impl SpdmContext { return Err(SPDM_STATUS_CRYPTO_ERROR); } - if self.negotiate_info.spdm_version_sel.get_u8() - >= SpdmVersion::SpdmVersion12.get_u8() - { + if self.negotiate_info.spdm_version_sel >= SpdmVersion::SpdmVersion12 { crypto::hash::hash_ctx_update( self.runtime_info.digest_context_l1l2.as_mut().unwrap(), self.runtime_info.message_a.as_ref(), diff --git a/spdmlib/src/common/opaque.rs b/spdmlib/src/common/opaque.rs index 5b7cb388..188a5d1c 100644 --- a/spdmlib/src/common/opaque.rs +++ b/spdmlib/src/common/opaque.rs @@ -272,7 +272,7 @@ impl SpdmCodec for SecuredMessageGeneralOpaqueDataHeader { bytes: &mut Writer, ) -> Result { let mut cnt = 0usize; - if context.negotiate_info.spdm_version_sel.get_u8() < SpdmVersion::SpdmVersion12.get_u8() { + if context.negotiate_info.spdm_version_sel < SpdmVersion::SpdmVersion12 { cnt += DMTF_SPEC_ID .encode(bytes) .map_err(|_| SPDM_STATUS_BUFFER_FULL)?; @@ -302,7 +302,7 @@ impl SpdmCodec for SecuredMessageGeneralOpaqueDataHeader { ) -> Option { let total_elements: u8; - if context.negotiate_info.spdm_version_sel.get_u8() < SpdmVersion::SpdmVersion12.get_u8() { + if context.negotiate_info.spdm_version_sel < SpdmVersion::SpdmVersion12 { let spec_id = u32::read(r)?; if spec_id != DMTF_SPEC_ID { return None; diff --git a/spdmlib/src/message/algorithm.rs b/spdmlib/src/message/algorithm.rs index c2d663c7..39d94074 100644 --- a/spdmlib/src/message/algorithm.rs +++ b/spdmlib/src/message/algorithm.rs @@ -29,7 +29,7 @@ impl SpdmCodec for SpdmNegotiateAlgorithmsRequestPayload { ) -> Result { let mut cnt = 0usize; - if context.negotiate_info.spdm_version_sel.get_u8() >= SpdmVersion::SpdmVersion11.get_u8() { + if context.negotiate_info.spdm_version_sel >= SpdmVersion::SpdmVersion11 { cnt += self .alg_struct_count .encode(bytes) @@ -41,7 +41,7 @@ impl SpdmCodec for SpdmNegotiateAlgorithmsRequestPayload { cnt += 0u8.encode(bytes).map_err(|_| SPDM_STATUS_BUFFER_FULL)?; // param2 let mut length: u16 = 32; - if context.negotiate_info.spdm_version_sel.get_u8() >= SpdmVersion::SpdmVersion11.get_u8() { + if context.negotiate_info.spdm_version_sel >= SpdmVersion::SpdmVersion11 { let alg_fixed_count = 2u8; length += ((2 + alg_fixed_count) * self.alg_struct_count) as u16; } @@ -52,7 +52,7 @@ impl SpdmCodec for SpdmNegotiateAlgorithmsRequestPayload { .encode(bytes) .map_err(|_| SPDM_STATUS_BUFFER_FULL)?; - if context.negotiate_info.spdm_version_sel.get_u8() >= SpdmVersion::SpdmVersion12.get_u8() { + if context.negotiate_info.spdm_version_sel >= SpdmVersion::SpdmVersion12 { cnt += self .other_params_support .encode(bytes) @@ -79,7 +79,7 @@ impl SpdmCodec for SpdmNegotiateAlgorithmsRequestPayload { cnt += 0u16.encode(bytes).map_err(|_| SPDM_STATUS_BUFFER_FULL)?; // reserved3 - if context.negotiate_info.spdm_version_sel.get_u8() >= SpdmVersion::SpdmVersion11.get_u8() { + if context.negotiate_info.spdm_version_sel >= SpdmVersion::SpdmVersion11 { for algo in self.alg_struct.iter().take(self.alg_struct_count as usize) { cnt += algo.encode(bytes).map_err(|_| SPDM_STATUS_BUFFER_FULL)?; } @@ -92,7 +92,7 @@ impl SpdmCodec for SpdmNegotiateAlgorithmsRequestPayload { r: &mut Reader, ) -> Option { let mut alg_struct_count = 0; - if context.negotiate_info.spdm_version_sel.get_u8() >= SpdmVersion::SpdmVersion11.get_u8() { + if context.negotiate_info.spdm_version_sel >= SpdmVersion::SpdmVersion11 { alg_struct_count = u8::read(r)?; // param1 if alg_struct_count > 4 { return None; @@ -105,14 +105,13 @@ impl SpdmCodec for SpdmNegotiateAlgorithmsRequestPayload { let length = u16::read(r)?; let measurement_specification = SpdmMeasurementSpecification::read(r)?; - let other_params_support = if context.negotiate_info.spdm_version_sel.get_u8() - >= SpdmVersion::SpdmVersion12.get_u8() - { - SpdmOpaqueSupport::read(r)? - } else { - u8::read(r)?; - SpdmOpaqueSupport::default() - }; + let other_params_support = + if context.negotiate_info.spdm_version_sel >= SpdmVersion::SpdmVersion12 { + SpdmOpaqueSupport::read(r)? + } else { + u8::read(r)?; + SpdmOpaqueSupport::default() + }; let base_asym_algo = SpdmBaseAsymAlgo::read(r)?; let base_hash_algo = SpdmBaseHashAlgo::read(r)?; @@ -134,7 +133,7 @@ impl SpdmCodec for SpdmNegotiateAlgorithmsRequestPayload { u16::read(r)?; // reserved3 let mut alg_struct = gen_array_clone(SpdmAlgStruct::default(), 4); - if context.negotiate_info.spdm_version_sel.get_u8() >= SpdmVersion::SpdmVersion11.get_u8() { + if context.negotiate_info.spdm_version_sel >= SpdmVersion::SpdmVersion11 { let mut dhe_present = false; let mut aead_present = false; let mut req_asym_present = false; @@ -183,7 +182,7 @@ impl SpdmCodec for SpdmNegotiateAlgorithmsRequestPayload { // check length // let mut calc_length: u16 = 32; - if context.negotiate_info.spdm_version_sel.get_u8() >= SpdmVersion::SpdmVersion11.get_u8() { + if context.negotiate_info.spdm_version_sel >= SpdmVersion::SpdmVersion11 { let alg_fixed_count = 2u8; calc_length += ((2 + alg_fixed_count) * alg_struct_count) as u16; } @@ -222,7 +221,7 @@ impl SpdmCodec for SpdmAlgorithmsResponsePayload { ) -> Result { let mut cnt = 0usize; - if context.negotiate_info.spdm_version_sel.get_u8() >= SpdmVersion::SpdmVersion11.get_u8() { + if context.negotiate_info.spdm_version_sel >= SpdmVersion::SpdmVersion11 { cnt += self .alg_struct_count .encode(bytes) @@ -234,7 +233,7 @@ impl SpdmCodec for SpdmAlgorithmsResponsePayload { cnt += 0u8.encode(bytes).map_err(|_| SPDM_STATUS_BUFFER_FULL)?; // param2 let mut length: u16 = 36; - if context.negotiate_info.spdm_version_sel.get_u8() >= SpdmVersion::SpdmVersion11.get_u8() { + if context.negotiate_info.spdm_version_sel >= SpdmVersion::SpdmVersion11 { let alg_fixed_count = 2u8; length += ((2 + alg_fixed_count) * self.alg_struct_count) as u16; } @@ -245,7 +244,7 @@ impl SpdmCodec for SpdmAlgorithmsResponsePayload { .encode(bytes) .map_err(|_| SPDM_STATUS_BUFFER_FULL)?; - if context.negotiate_info.spdm_version_sel.get_u8() >= SpdmVersion::SpdmVersion12.get_u8() { + if context.negotiate_info.spdm_version_sel >= SpdmVersion::SpdmVersion12 { cnt += self .other_params_selection .encode(bytes) @@ -276,7 +275,7 @@ impl SpdmCodec for SpdmAlgorithmsResponsePayload { cnt += 0u16.encode(bytes).map_err(|_| SPDM_STATUS_BUFFER_FULL)?; // reserved3 - if context.negotiate_info.spdm_version_sel.get_u8() >= SpdmVersion::SpdmVersion11.get_u8() { + if context.negotiate_info.spdm_version_sel >= SpdmVersion::SpdmVersion11 { for algo in self.alg_struct.iter().take(self.alg_struct_count as usize) { cnt += algo.encode(bytes).map_err(|_| SPDM_STATUS_BUFFER_FULL)?; } @@ -289,7 +288,7 @@ impl SpdmCodec for SpdmAlgorithmsResponsePayload { r: &mut Reader, ) -> Option { let mut alg_struct_count = 0; - if context.negotiate_info.spdm_version_sel.get_u8() >= SpdmVersion::SpdmVersion11.get_u8() { + if context.negotiate_info.spdm_version_sel >= SpdmVersion::SpdmVersion11 { alg_struct_count = u8::read(r)?; // param1 if alg_struct_count > 4 { return None; @@ -318,18 +317,17 @@ impl SpdmCodec for SpdmAlgorithmsResponsePayload { return None; } - let other_params_selection = if context.negotiate_info.spdm_version_sel.get_u8() - >= SpdmVersion::SpdmVersion12.get_u8() - { - SpdmOpaqueSupport::read(r)? - } else { - u8::read(r)?; - SpdmOpaqueSupport::default() - }; + let other_params_selection = + if context.negotiate_info.spdm_version_sel >= SpdmVersion::SpdmVersion12 { + SpdmOpaqueSupport::read(r)? + } else { + u8::read(r)?; + SpdmOpaqueSupport::default() + }; if !other_params_selection.is_no_more_than_one_selected() { return None; } - if context.negotiate_info.spdm_version_sel.get_u8() >= SpdmVersion::SpdmVersion12.get_u8() + if context.negotiate_info.spdm_version_sel >= SpdmVersion::SpdmVersion12 && (context .negotiate_info .rsp_capabilities_sel @@ -451,7 +449,7 @@ impl SpdmCodec for SpdmAlgorithmsResponsePayload { u16::read(r)?; // reserved3 let mut alg_struct = gen_array_clone(SpdmAlgStruct::default(), 4); - if context.negotiate_info.spdm_version_sel.get_u8() >= SpdmVersion::SpdmVersion11.get_u8() { + if context.negotiate_info.spdm_version_sel >= SpdmVersion::SpdmVersion11 { let mut dhe_present = false; let mut aead_present = false; let mut req_asym_present = false; @@ -581,7 +579,7 @@ impl SpdmCodec for SpdmAlgorithmsResponsePayload { } let mut calc_length: u16 = 36; - if context.negotiate_info.spdm_version_sel.get_u8() >= SpdmVersion::SpdmVersion11.get_u8() { + if context.negotiate_info.spdm_version_sel >= SpdmVersion::SpdmVersion11 { let alg_fixed_count = 2u8; calc_length += ((2 + alg_fixed_count) * alg_struct_count) as u16; } diff --git a/spdmlib/src/message/capability.rs b/spdmlib/src/message/capability.rs index 43203c14..31e09837 100644 --- a/spdmlib/src/message/capability.rs +++ b/spdmlib/src/message/capability.rs @@ -27,7 +27,7 @@ impl SpdmCodec for SpdmGetCapabilitiesRequestPayload { cnt += 0u8.encode(bytes).map_err(|_| SPDM_STATUS_BUFFER_FULL)?; // param1 cnt += 0u8.encode(bytes).map_err(|_| SPDM_STATUS_BUFFER_FULL)?; // param2 - if context.negotiate_info.spdm_version_sel.get_u8() >= SpdmVersion::SpdmVersion11.get_u8() { + if context.negotiate_info.spdm_version_sel >= SpdmVersion::SpdmVersion11 { cnt += 0u8.encode(bytes).map_err(|_| SPDM_STATUS_BUFFER_FULL)?; // reserved cnt += self .ct_exponent @@ -40,7 +40,7 @@ impl SpdmCodec for SpdmGetCapabilitiesRequestPayload { .map_err(|_| SPDM_STATUS_BUFFER_FULL)?; } - if context.negotiate_info.spdm_version_sel.get_u8() >= SpdmVersion::SpdmVersion12.get_u8() { + if context.negotiate_info.spdm_version_sel >= SpdmVersion::SpdmVersion12 { cnt += self .data_transfer_size .encode(bytes) @@ -62,7 +62,7 @@ impl SpdmCodec for SpdmGetCapabilitiesRequestPayload { let mut ct_exponent = 0; let mut flags = SpdmRequestCapabilityFlags::default(); - if context.negotiate_info.spdm_version_sel.get_u8() >= SpdmVersion::SpdmVersion11.get_u8() { + if context.negotiate_info.spdm_version_sel >= SpdmVersion::SpdmVersion11 { u8::read(r)?; // reserved ct_exponent = u8::read(r)?; u16::read(r)?; // reserved2 @@ -111,8 +111,7 @@ impl SpdmCodec for SpdmGetCapabilitiesRequestPayload { return None; } - if context.negotiate_info.spdm_version_sel.get_u8() - == SpdmVersion::SpdmVersion11.get_u8() + if context.negotiate_info.spdm_version_sel == SpdmVersion::SpdmVersion11 && flags.contains(SpdmRequestCapabilityFlags::MUT_AUTH_CAP) && !flags.contains(SpdmRequestCapabilityFlags::ENCAP_CAP) { @@ -122,7 +121,7 @@ impl SpdmCodec for SpdmGetCapabilitiesRequestPayload { let mut data_transfer_size = 0; let mut max_spdm_msg_size = 0; - if context.negotiate_info.spdm_version_sel.get_u8() >= SpdmVersion::SpdmVersion12.get_u8() { + if context.negotiate_info.spdm_version_sel >= SpdmVersion::SpdmVersion12 { data_transfer_size = u32::read(r)?; max_spdm_msg_size = u32::read(r)?; if data_transfer_size < 42 || max_spdm_msg_size < data_transfer_size { @@ -171,7 +170,7 @@ impl SpdmCodec for SpdmCapabilitiesResponsePayload { .encode(bytes) .map_err(|_| SPDM_STATUS_BUFFER_FULL)?; - if context.negotiate_info.spdm_version_sel.get_u8() >= SpdmVersion::SpdmVersion12.get_u8() { + if context.negotiate_info.spdm_version_sel >= SpdmVersion::SpdmVersion12 { cnt += self .data_transfer_size .encode(bytes) @@ -209,7 +208,7 @@ impl SpdmCodec for SpdmCapabilitiesResponsePayload { { return None; } - if context.negotiate_info.spdm_version_sel.get_u8() < SpdmVersion::SpdmVersion11.get_u8() { + if context.negotiate_info.spdm_version_sel < SpdmVersion::SpdmVersion11 { if !flags.contains(SpdmResponseCapabilityFlags::MEAS_CAP_SIG) { if flags.contains(SpdmResponseCapabilityFlags::CERT_CAP) != flags.contains(SpdmResponseCapabilityFlags::CHAL_CAP) @@ -269,13 +268,13 @@ impl SpdmCodec for SpdmCapabilitiesResponsePayload { return None; } } - if context.negotiate_info.spdm_version_sel.get_u8() == SpdmVersion::SpdmVersion11.get_u8() + if context.negotiate_info.spdm_version_sel == SpdmVersion::SpdmVersion11 && flags.contains(SpdmResponseCapabilityFlags::MUT_AUTH_CAP) && !flags.contains(SpdmResponseCapabilityFlags::ENCAP_CAP) { return None; } - if context.negotiate_info.spdm_version_sel.get_u8() >= SpdmVersion::SpdmVersion12.get_u8() { + if context.negotiate_info.spdm_version_sel >= SpdmVersion::SpdmVersion12 { if !flags.contains(SpdmResponseCapabilityFlags::CERT_CAP) && (flags.contains(SpdmResponseCapabilityFlags::ALIAS_CERT_CAP) || flags.contains(SpdmResponseCapabilityFlags::SET_CERT_CAP)) @@ -295,7 +294,7 @@ impl SpdmCodec for SpdmCapabilitiesResponsePayload { } } - if context.negotiate_info.spdm_version_sel.get_u8() >= SpdmVersion::SpdmVersion12.get_u8() { + if context.negotiate_info.spdm_version_sel >= SpdmVersion::SpdmVersion12 { let data_transfer_size = u32::read(r)?; let max_spdm_msg_size = u32::read(r)?; if data_transfer_size < 42 || max_spdm_msg_size < data_transfer_size { diff --git a/spdmlib/src/message/encapsulated.rs b/spdmlib/src/message/encapsulated.rs index 03e6c79b..4fb0c088 100644 --- a/spdmlib/src/message/encapsulated.rs +++ b/spdmlib/src/message/encapsulated.rs @@ -121,7 +121,7 @@ impl SpdmCodec for SpdmEncapsulatedResponseAckPayload { .encode(bytes) .map_err(|_| SPDM_STATUS_BUFFER_FULL)?; // param2 - if context.negotiate_info.spdm_version_sel.get_u8() >= SpdmVersion::SpdmVersion12.get_u8() { + if context.negotiate_info.spdm_version_sel >= SpdmVersion::SpdmVersion12 { cnt += self .ack_request_id .encode(bytes) @@ -142,7 +142,7 @@ impl SpdmCodec for SpdmEncapsulatedResponseAckPayload { let payload_type = SpdmEncapsulatedResponseAckPayloadType::read(r)?; // param2 let mut ack_request_id = 0; - if context.negotiate_info.spdm_version_sel.get_u8() >= SpdmVersion::SpdmVersion12.get_u8() { + if context.negotiate_info.spdm_version_sel >= SpdmVersion::SpdmVersion12 { ack_request_id = u8::read(r)?; let _ = u24::read(r)?; // reserved } diff --git a/spdmlib/src/message/key_exchange.rs b/spdmlib/src/message/key_exchange.rs index 89e26bcf..7c4f9fd9 100644 --- a/spdmlib/src/message/key_exchange.rs +++ b/spdmlib/src/message/key_exchange.rs @@ -48,7 +48,7 @@ impl SpdmCodec for SpdmKeyExchangeRequestPayload { .encode(bytes) .map_err(|_| SPDM_STATUS_BUFFER_FULL)?; - if context.negotiate_info.spdm_version_sel.get_u8() >= SpdmVersion::SpdmVersion12.get_u8() { + if context.negotiate_info.spdm_version_sel >= SpdmVersion::SpdmVersion12 { cnt += self .session_policy .encode(bytes) diff --git a/spdmlib/src/message/key_update_test.rs b/spdmlib/src/message/key_update_test.rs index 5aa72f37..d897eeef 100644 --- a/spdmlib/src/message/key_update_test.rs +++ b/spdmlib/src/message/key_update_test.rs @@ -18,7 +18,7 @@ fn test_key_update_struct() { // 1. Validate KeyUpdate request length is 4. let u8_slice = &mut [ - SpdmVersion::SpdmVersion11.get_u8(), + u8::from(SpdmVersion::SpdmVersion11), SpdmRequestResponseCode::SpdmRequestKeyUpdate.get_u8(), SpdmKeyUpdateOperation::SpdmUpdateSingleKey.get_u8(), 0u8, @@ -30,7 +30,7 @@ fn test_key_update_struct() { // 2. Validate KEY_UPDATE_ACK response length is 4. let u8_slice = &mut [ - SpdmVersion::SpdmVersion11.get_u8(), + u8::from(SpdmVersion::SpdmVersion11), SpdmRequestResponseCode::SpdmResponseKeyUpdateAck.get_u8(), SpdmKeyUpdateOperation::SpdmUpdateSingleKey.get_u8(), 0u8, @@ -42,7 +42,7 @@ fn test_key_update_struct() { // 3. Validate KEY_UPDATE operations equal to reserved value. Expactation, fail. let u8_slice = &mut [ - SpdmVersion::SpdmVersion11.get_u8(), + u8::from(SpdmVersion::SpdmVersion11), SpdmRequestResponseCode::SpdmRequestKeyUpdate.get_u8(), SpdmKeyUpdateOperation::SpdmUpdateSingleKey.get_u8(), 0u8, @@ -54,7 +54,7 @@ fn test_key_update_struct() { // 4. Validate KEY_UPDATE_ACK KEY_UPDATE operations equal to reserved value. Expectation, fail let u8_slice = &mut [ - SpdmVersion::SpdmVersion11.get_u8(), + u8::from(SpdmVersion::SpdmVersion11), SpdmRequestResponseCode::SpdmResponseKeyUpdateAck.get_u8(), SpdmKeyUpdateOperation::SpdmUpdateSingleKey.get_u8(), 0u8, diff --git a/spdmlib/src/message/measurement.rs b/spdmlib/src/message/measurement.rs index 571dc3bf..476c86e6 100644 --- a/spdmlib/src/message/measurement.rs +++ b/spdmlib/src/message/measurement.rs @@ -82,9 +82,7 @@ impl SpdmCodec for SpdmGetMeasurementsRequestPayload { .nonce .encode(bytes) .map_err(|_| SPDM_STATUS_BUFFER_FULL)?; - if context.negotiate_info.spdm_version_sel.get_u8() - >= SpdmVersion::SpdmVersion11.get_u8() - { + if context.negotiate_info.spdm_version_sel >= SpdmVersion::SpdmVersion11 { cnt += self .slot_id .encode(bytes) @@ -108,9 +106,7 @@ impl SpdmCodec for SpdmGetMeasurementsRequestPayload { }; let slot_id = if measurement_attributes.contains(SpdmMeasurementAttributes::SIGNATURE_REQUESTED) { - if context.negotiate_info.spdm_version_sel.get_u8() - >= SpdmVersion::SpdmVersion11.get_u8() - { + if context.negotiate_info.spdm_version_sel >= SpdmVersion::SpdmVersion11 { u8::read(r)? } else { 0 @@ -157,14 +153,13 @@ impl SpdmCodec for SpdmMeasurementsResponsePayload { .encode(bytes) .map_err(|_| SPDM_STATUS_BUFFER_FULL)?; // param1 } - if context.negotiate_info.spdm_version_sel.get_u8() >= SpdmVersion::SpdmVersion12.get_u8() + if context.negotiate_info.spdm_version_sel >= SpdmVersion::SpdmVersion12 && context.runtime_info.need_measurement_signature { cnt += (self.slot_id | self.content_changed.bits()) .encode(bytes) .map_err(|_| SPDM_STATUS_BUFFER_FULL)?; // param2 - } else if context.negotiate_info.spdm_version_sel.get_u8() - >= SpdmVersion::SpdmVersion11.get_u8() + } else if context.negotiate_info.spdm_version_sel >= SpdmVersion::SpdmVersion11 && context.runtime_info.need_measurement_signature { cnt += self diff --git a/spdmlib/src/protocol/version.rs b/spdmlib/src/protocol/version.rs index 55002e16..aa627345 100644 --- a/spdmlib/src/protocol/version.rs +++ b/spdmlib/src/protocol/version.rs @@ -2,21 +2,61 @@ // // SPDX-License-Identifier: Apache-2.0 -use codec::enum_builder; -use codec::{Codec, Reader, Writer}; +use codec::{Codec, EncodeErr}; +use core::convert::TryFrom; -enum_builder! { - @U8 - EnumName: SpdmVersion; - EnumVal{ - SpdmVersion10 => 0x10, - SpdmVersion11 => 0x11, - SpdmVersion12 => 0x12 - } +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)] +pub enum SpdmVersion { + SpdmVersion10, + SpdmVersion11, + SpdmVersion12, } + impl Default for SpdmVersion { - fn default() -> SpdmVersion { - SpdmVersion::Unknown(0) + fn default() -> Self { + Self::SpdmVersion10 + } +} + +impl TryFrom for SpdmVersion { + type Error = (); + fn try_from(untrusted_spdm_version: u8) -> Result>::Error> { + if untrusted_spdm_version == 0x10 { + Ok(SpdmVersion::SpdmVersion10) + } else if untrusted_spdm_version == 0x11 { + Ok(SpdmVersion::SpdmVersion11) + } else if untrusted_spdm_version == 0x12 { + Ok(SpdmVersion::SpdmVersion12) + } else { + Err(()) + } + } +} + +impl From for u8 { + fn from(spdm_version: SpdmVersion) -> Self { + match spdm_version { + SpdmVersion::SpdmVersion10 => 0x10, + SpdmVersion::SpdmVersion11 => 0x11, + SpdmVersion::SpdmVersion12 => 0x12, + } + } +} + +impl From<&SpdmVersion> for u8 { + fn from(spdm_version: &SpdmVersion) -> Self { + u8::from(*spdm_version) + } +} + +impl Codec for SpdmVersion { + fn encode(&self, bytes: &mut codec::Writer<'_>) -> Result { + u8::from(self).encode(bytes) + } + + fn read(r: &mut codec::Reader<'_>) -> Option { + let spdm_version = u8::read(r)?; + Self::try_from(spdm_version).ok() } } diff --git a/spdmlib/src/requester/challenge_req.rs b/spdmlib/src/requester/challenge_req.rs index fa5f66b3..83299ef3 100644 --- a/spdmlib/src/requester/challenge_req.rs +++ b/spdmlib/src/requester/challenge_req.rs @@ -177,9 +177,7 @@ impl RequesterContext { let mut message_sign = ManagedBuffer12Sign::default(); - if self.common.negotiate_info.spdm_version_sel.get_u8() - >= SpdmVersion::SpdmVersion12.get_u8() - { + if self.common.negotiate_info.spdm_version_sel >= SpdmVersion::SpdmVersion12 { message_sign.reset_message(); message_sign .append_message(&SPDM_VERSION_1_2_SIGNING_PREFIX_CONTEXT) @@ -247,9 +245,7 @@ impl RequesterContext { .ok_or(SPDM_STATUS_INVALID_PARAMETER)? .data_size as usize)]; - if self.common.negotiate_info.spdm_version_sel.get_u8() - >= SpdmVersion::SpdmVersion12.get_u8() - { + if self.common.negotiate_info.spdm_version_sel >= SpdmVersion::SpdmVersion12 { message_m1m2.reset_message(); message_m1m2 .append_message(&SPDM_VERSION_1_2_SIGNING_PREFIX_CONTEXT) diff --git a/spdmlib/src/requester/encap_req.rs b/spdmlib/src/requester/encap_req.rs index 730cd8c0..1a5839e0 100644 --- a/spdmlib/src/requester/encap_req.rs +++ b/spdmlib/src/requester/encap_req.rs @@ -31,9 +31,7 @@ impl RequesterContext { session_id: u32, mut_auth_requested: SpdmKeyExchangeMutAuthAttributes, ) -> SpdmResult { - if self.common.negotiate_info.spdm_version_sel.get_u8() - < SpdmVersion::SpdmVersion11.get_u8() - { + if self.common.negotiate_info.spdm_version_sel < SpdmVersion::SpdmVersion11 { return Err(SPDM_STATUS_UNSUPPORTED_CAP); } diff --git a/spdmlib/src/requester/finish_req.rs b/spdmlib/src/requester/finish_req.rs index 2e216a88..61f8cbf0 100644 --- a/spdmlib/src/requester/finish_req.rs +++ b/spdmlib/src/requester/finish_req.rs @@ -326,9 +326,7 @@ impl RequesterContext { .calc_req_transcript_hash(false, slot_id, true, session)?; let mut transcript_sign = ManagedBuffer12Sign::default(); - if self.common.negotiate_info.spdm_version_sel.get_u8() - >= SpdmVersion::SpdmVersion12.get_u8() - { + if self.common.negotiate_info.spdm_version_sel >= SpdmVersion::SpdmVersion12 { transcript_sign.reset_message(); transcript_sign .append_message(&SPDM_VERSION_1_2_SIGNING_PREFIX_CONTEXT) @@ -365,9 +363,7 @@ impl RequesterContext { debug!("transcript_hash - {:02x?}", transcript_hash.as_ref()); let mut transcript_sign = ManagedBuffer12Sign::default(); - if self.common.negotiate_info.spdm_version_sel.get_u8() - >= SpdmVersion::SpdmVersion12.get_u8() - { + if self.common.negotiate_info.spdm_version_sel >= SpdmVersion::SpdmVersion12 { transcript_sign.reset_message(); transcript_sign .append_message(&SPDM_VERSION_1_2_SIGNING_PREFIX_CONTEXT) diff --git a/spdmlib/src/requester/get_capabilities_req.rs b/spdmlib/src/requester/get_capabilities_req.rs index 24479218..cd5be864 100644 --- a/spdmlib/src/requester/get_capabilities_req.rs +++ b/spdmlib/src/requester/get_capabilities_req.rs @@ -74,8 +74,8 @@ impl RequesterContext { capabilities.ct_exponent; self.common.negotiate_info.rsp_capabilities_sel = capabilities.flags; - if self.common.negotiate_info.spdm_version_sel.get_u8() - >= SpdmVersion::SpdmVersion12.get_u8() + if self.common.negotiate_info.spdm_version_sel + >= SpdmVersion::SpdmVersion12 { self.common.negotiate_info.req_data_transfer_size_sel = self.common.config_info.data_transfer_size; diff --git a/spdmlib/src/requester/get_measurements_req.rs b/spdmlib/src/requester/get_measurements_req.rs index e337917b..68740e35 100644 --- a/spdmlib/src/requester/get_measurements_req.rs +++ b/spdmlib/src/requester/get_measurements_req.rs @@ -118,8 +118,8 @@ impl RequesterContext { if let Some(measurements) = measurements { debug!("!!! measurements : {:02x?}\n", measurements); - if self.common.negotiate_info.spdm_version_sel.get_u8() - >= SpdmVersion::SpdmVersion12.get_u8() + if self.common.negotiate_info.spdm_version_sel + >= SpdmVersion::SpdmVersion12 { self.common.runtime_info.content_changed = measurements.content_changed; @@ -270,9 +270,7 @@ impl RequesterContext { .data_size as usize)]; let mut message_sign = ManagedBuffer12Sign::default(); - if self.common.negotiate_info.spdm_version_sel.get_u8() - >= SpdmVersion::SpdmVersion12.get_u8() - { + if self.common.negotiate_info.spdm_version_sel >= SpdmVersion::SpdmVersion12 { message_sign.reset_message(); message_sign .append_message(&SPDM_VERSION_1_2_SIGNING_PREFIX_CONTEXT) @@ -309,9 +307,7 @@ impl RequesterContext { ) -> SpdmResult { let mut message_l1l2 = ManagedBufferL1L2::default(); - if self.common.negotiate_info.spdm_version_sel.get_u8() - >= SpdmVersion::SpdmVersion12.get_u8() - { + if self.common.negotiate_info.spdm_version_sel >= SpdmVersion::SpdmVersion12 { let message_a = self.common.runtime_info.message_a.clone(); message_l1l2 .append_message(message_a.as_ref()) @@ -361,9 +357,7 @@ impl RequesterContext { .ok_or(SPDM_STATUS_INVALID_PARAMETER)? .data_size as usize)]; - if self.common.negotiate_info.spdm_version_sel.get_u8() - >= SpdmVersion::SpdmVersion12.get_u8() - { + if self.common.negotiate_info.spdm_version_sel >= SpdmVersion::SpdmVersion12 { message_l1l2.reset_message(); message_l1l2 .append_message(&SPDM_VERSION_1_2_SIGNING_PREFIX_CONTEXT) diff --git a/spdmlib/src/requester/get_version_req.rs b/spdmlib/src/requester/get_version_req.rs index 512239a9..9f49f758 100644 --- a/spdmlib/src/requester/get_version_req.rs +++ b/spdmlib/src/requester/get_version_req.rs @@ -59,10 +59,9 @@ impl RequesterContext { mut versions, } = version; - versions - .sort_unstable_by(|a, b| b.version.get_u8().cmp(&a.version.get_u8())); + versions.sort_unstable_by(|a, b| b.version.cmp(&a.version)); - self.common.negotiate_info.spdm_version_sel = SpdmVersion::Unknown(0); + let mut negotiate_version: Option = None; for spdm_version_struct in versions.iter().take(version_number_entry_count as usize) @@ -73,26 +72,23 @@ impl RequesterContext { .spdm_version .contains(&spdm_version_struct.version) { - self.common.negotiate_info.spdm_version_sel = - spdm_version_struct.version; + negotiate_version = Some(spdm_version_struct.version); break; } } - match self.common.negotiate_info.spdm_version_sel { - SpdmVersion::Unknown(_) => { - debug!( - "Version negotiation failed! with given version list: {:?}", - versions - ); - return Err(SPDM_STATUS_NEGOTIATION_FAIL); - } - _ => { - debug!( - "Version negotiated: {:?}", - self.common.negotiate_info.spdm_version_sel - ); - } + if let Some(negotiate_version) = negotiate_version { + self.common.negotiate_info.spdm_version_sel = negotiate_version; + debug!( + "Version negotiated: {:?}", + self.common.negotiate_info.spdm_version_sel + ); + } else { + debug!( + "Version negotiation failed! with given version list: {:?}", + versions + ); + return Err(SPDM_STATUS_NEGOTIATION_FAIL); } // clear cache data diff --git a/spdmlib/src/requester/key_exchange_req.rs b/spdmlib/src/requester/key_exchange_req.rs index da6f974d..30ba176c 100644 --- a/spdmlib/src/requester/key_exchange_req.rs +++ b/spdmlib/src/requester/key_exchange_req.rs @@ -87,9 +87,7 @@ impl RequesterContext { debug!("!!! exchange data : {:02x?}\n", exchange); let mut opaque; - if self.common.negotiate_info.spdm_version_sel.get_u8() - < SpdmVersion::SpdmVersion12.get_u8() - { + if self.common.negotiate_info.spdm_version_sel < SpdmVersion::SpdmVersion12 { opaque = SpdmOpaqueStruct { data_size: crate::common::opaque::REQ_DMTF_OPAQUE_DATA_SUPPORT_VERSION_LIST_DSP0277 .len() as u16, @@ -450,9 +448,7 @@ impl RequesterContext { .data_size as usize)]; let mut message_sign = ManagedBuffer12Sign::default(); - if self.common.negotiate_info.spdm_version_sel.get_u8() - >= SpdmVersion::SpdmVersion12.get_u8() - { + if self.common.negotiate_info.spdm_version_sel >= SpdmVersion::SpdmVersion12 { message_sign.reset_message(); message_sign .append_message(&SPDM_VERSION_1_2_SIGNING_PREFIX_CONTEXT) @@ -516,9 +512,7 @@ impl RequesterContext { None, )?; - if self.common.negotiate_info.spdm_version_sel.get_u8() - >= SpdmVersion::SpdmVersion12.get_u8() - { + if self.common.negotiate_info.spdm_version_sel >= SpdmVersion::SpdmVersion12 { message.reset_message(); message .append_message(&SPDM_VERSION_1_2_SIGNING_PREFIX_CONTEXT) diff --git a/spdmlib/src/requester/psk_exchange_req.rs b/spdmlib/src/requester/psk_exchange_req.rs index e0ad7e3a..a4f0a394 100644 --- a/spdmlib/src/requester/psk_exchange_req.rs +++ b/spdmlib/src/requester/psk_exchange_req.rs @@ -74,9 +74,7 @@ impl RequesterContext { crypto::rand::get_random(&mut psk_context)?; let mut opaque; - if self.common.negotiate_info.spdm_version_sel.get_u8() - < SpdmVersion::SpdmVersion12.get_u8() - { + if self.common.negotiate_info.spdm_version_sel < SpdmVersion::SpdmVersion12 { opaque = SpdmOpaqueStruct { data_size: crate::common::opaque::REQ_DMTF_OPAQUE_DATA_SUPPORT_VERSION_LIST_DSP0277 .len() as u16, diff --git a/spdmlib/src/responder/capability_rsp.rs b/spdmlib/src/responder/capability_rsp.rs index 07691df8..fb82babe 100644 --- a/spdmlib/src/responder/capability_rsp.rs +++ b/spdmlib/src/responder/capability_rsp.rs @@ -34,7 +34,7 @@ impl ResponderContext { request_response_code: _, }) = message_header { - if version.get_u8() < SpdmVersion::SpdmVersion10.get_u8() { + if version < SpdmVersion::SpdmVersion10 { self.write_spdm_error(SpdmErrorCode::SpdmErrorInvalidRequest, 0, writer); return; } @@ -61,9 +61,7 @@ impl ResponderContext { self.common.negotiate_info.rsp_capabilities_sel = self.common.config_info.rsp_capabilities; - if self.common.negotiate_info.spdm_version_sel.get_u8() - >= SpdmVersion::SpdmVersion12.get_u8() - { + if self.common.negotiate_info.spdm_version_sel >= SpdmVersion::SpdmVersion12 { self.common.negotiate_info.req_data_transfer_size_sel = get_capabilities.data_transfer_size; self.common.negotiate_info.req_max_spdm_msg_size_sel = diff --git a/spdmlib/src/responder/challenge_rsp.rs b/spdmlib/src/responder/challenge_rsp.rs index 6c7175b7..adffb011 100644 --- a/spdmlib/src/responder/challenge_rsp.rs +++ b/spdmlib/src/responder/challenge_rsp.rs @@ -200,9 +200,7 @@ impl ResponderContext { debug!("message_m1m2_hash - {:02x?}", message_m1m2_hash.as_ref()); let mut message_sign = ManagedBuffer12Sign::default(); - if self.common.negotiate_info.spdm_version_sel.get_u8() - >= SpdmVersion::SpdmVersion12.get_u8() - { + if self.common.negotiate_info.spdm_version_sel >= SpdmVersion::SpdmVersion12 { message_sign.reset_message(); message_sign .append_message(&SPDM_VERSION_1_2_SIGNING_PREFIX_CONTEXT) @@ -250,9 +248,7 @@ impl ResponderContext { .ok_or(SPDM_STATUS_CRYPTO_ERROR)?; debug!("message_m1m2_hash - {:02x?}", message_m1m2_hash.as_ref()); - if self.common.negotiate_info.spdm_version_sel.get_u8() - >= SpdmVersion::SpdmVersion12.get_u8() - { + if self.common.negotiate_info.spdm_version_sel >= SpdmVersion::SpdmVersion12 { message_m1m2.reset_message(); message_m1m2 .append_message(&SPDM_VERSION_1_2_SIGNING_PREFIX_CONTEXT) diff --git a/spdmlib/src/responder/encap_rsp.rs b/spdmlib/src/responder/encap_rsp.rs index 23eb3eaf..230b803b 100644 --- a/spdmlib/src/responder/encap_rsp.rs +++ b/spdmlib/src/responder/encap_rsp.rs @@ -126,9 +126,7 @@ impl ResponderContext { request_response_code: u8, writer: &mut Writer<'_>, ) -> SpdmResult { - if self.common.negotiate_info.spdm_version_sel.get_u8() - < SpdmVersion::SpdmVersion11.get_u8() - { + if self.common.negotiate_info.spdm_version_sel < SpdmVersion::SpdmVersion11 { self.write_spdm_error( SpdmErrorCode::SpdmErrorUnsupportedRequest, request_response_code, diff --git a/spdmlib/src/responder/error_rsp.rs b/spdmlib/src/responder/error_rsp.rs index 7263a9f8..279934b7 100644 --- a/spdmlib/src/responder/error_rsp.rs +++ b/spdmlib/src/responder/error_rsp.rs @@ -5,7 +5,6 @@ use crate::common::SpdmCodec; use crate::error::SpdmResult; use crate::message::*; -use crate::protocol::SpdmVersion; use crate::responder::*; impl ResponderContext { @@ -17,11 +16,7 @@ impl ResponderContext { ) { let error = SpdmMessage { header: SpdmMessageHeader { - version: if self.common.negotiate_info.spdm_version_sel.get_u8() == 0 { - SpdmVersion::SpdmVersion10 - } else { - self.common.negotiate_info.spdm_version_sel - }, + version: self.common.negotiate_info.spdm_version_sel, request_response_code: SpdmRequestResponseCode::SpdmResponseError, }, payload: SpdmMessagePayload::SpdmErrorResponse(SpdmErrorResponsePayload { diff --git a/spdmlib/src/responder/finish_rsp.rs b/spdmlib/src/responder/finish_rsp.rs index dece24ed..b2db052a 100644 --- a/spdmlib/src/responder/finish_rsp.rs +++ b/spdmlib/src/responder/finish_rsp.rs @@ -303,9 +303,7 @@ impl ResponderContext { .ok_or(SPDM_STATUS_INVALID_PARAMETER)? .data_size as usize)]; let mut transcript_sign = ManagedBuffer12Sign::default(); - if self.common.negotiate_info.spdm_version_sel.get_u8() - >= SpdmVersion::SpdmVersion12.get_u8() - { + if self.common.negotiate_info.spdm_version_sel >= SpdmVersion::SpdmVersion12 { transcript_sign.reset_message(); transcript_sign .append_message(&SPDM_VERSION_1_2_SIGNING_PREFIX_CONTEXT) @@ -351,9 +349,7 @@ impl ResponderContext { .data_size as usize)]; let mut transcript_hash_sign = ManagedBuffer12Sign::default(); - if self.common.negotiate_info.spdm_version_sel.get_u8() - >= SpdmVersion::SpdmVersion12.get_u8() - { + if self.common.negotiate_info.spdm_version_sel >= SpdmVersion::SpdmVersion12 { transcript_hash_sign.reset_message(); transcript_hash_sign .append_message(&SPDM_VERSION_1_2_SIGNING_PREFIX_CONTEXT) diff --git a/spdmlib/src/responder/key_exchange_rsp.rs b/spdmlib/src/responder/key_exchange_rsp.rs index 40138cbe..f9528618 100644 --- a/spdmlib/src/responder/key_exchange_rsp.rs +++ b/spdmlib/src/responder/key_exchange_rsp.rs @@ -52,7 +52,7 @@ impl ResponderContext { self.write_spdm_error(SpdmErrorCode::SpdmErrorVersionMismatch, 0, writer); return Err(SPDM_STATUS_INVALID_MSG_FIELD); } - if message_header.version.get_u8() < SpdmVersion::SpdmVersion11.get_u8() { + if message_header.version < SpdmVersion::SpdmVersion11 { self.write_spdm_error(SpdmErrorCode::SpdmErrorUnsupportedRequest, 0, writer); return Err(SPDM_STATUS_INVALID_MSG_FIELD); } @@ -121,8 +121,8 @@ impl ResponderContext { .get_secure_spdm_version() == local_version { - if self.common.negotiate_info.spdm_version_sel.get_u8() - < SpdmVersion::SpdmVersion12.get_u8() + if self.common.negotiate_info.spdm_version_sel + < SpdmVersion::SpdmVersion12 { return_opaque.data_size = crate::common::opaque::RSP_DMTF_OPAQUE_DATA_VERSION_SELECTION_DSP0277 @@ -459,9 +459,7 @@ impl ResponderContext { debug!("message_hash - {:02x?}", transcript_hash.as_ref()); let mut message_sign = ManagedBuffer12Sign::default(); - if self.common.negotiate_info.spdm_version_sel.get_u8() - >= SpdmVersion::SpdmVersion12.get_u8() - { + if self.common.negotiate_info.spdm_version_sel >= SpdmVersion::SpdmVersion12 { message_sign.reset_message(); message_sign .append_message(&SPDM_VERSION_1_2_SIGNING_PREFIX_CONTEXT) @@ -508,9 +506,7 @@ impl ResponderContext { &session.runtime_info.message_k, None, )?; - if self.common.negotiate_info.spdm_version_sel.get_u8() - >= SpdmVersion::SpdmVersion12.get_u8() - { + if self.common.negotiate_info.spdm_version_sel >= SpdmVersion::SpdmVersion12 { message.reset_message(); message .append_message(&SPDM_VERSION_1_2_SIGNING_PREFIX_CONTEXT) diff --git a/spdmlib/src/responder/measurement_rsp.rs b/spdmlib/src/responder/measurement_rsp.rs index eeefad4e..336fe380 100644 --- a/spdmlib/src/responder/measurement_rsp.rs +++ b/spdmlib/src/responder/measurement_rsp.rs @@ -164,13 +164,12 @@ impl ResponderContext { SpdmMeasurementRecordStructure::default() }; - let content_changed = if runtime_content_change_support - && (spdm_version_sel.get_u8() >= SpdmVersion::SpdmVersion12.get_u8()) - { - content_changed - } else { - SpdmMeasurementContentChanged::NOT_SUPPORTED - }; + let content_changed = + if runtime_content_change_support && (spdm_version_sel >= SpdmVersion::SpdmVersion12) { + content_changed + } else { + SpdmMeasurementContentChanged::NOT_SUPPORTED + }; let mut nonce = [0u8; SPDM_NONCE_SIZE]; let res = crypto::rand::get_random(&mut nonce); @@ -280,9 +279,7 @@ impl ResponderContext { let mut message_sign = ManagedBuffer12Sign::default(); - if self.common.negotiate_info.spdm_version_sel.get_u8() - >= SpdmVersion::SpdmVersion12.get_u8() - { + if self.common.negotiate_info.spdm_version_sel >= SpdmVersion::SpdmVersion12 { message_sign.reset_message(); message_sign .append_message(&SPDM_VERSION_1_2_SIGNING_PREFIX_CONTEXT) @@ -315,9 +312,7 @@ impl ResponderContext { session_id: Option, ) -> SpdmResult { let mut message_l1l2 = ManagedBufferL1L2::default(); - if self.common.negotiate_info.spdm_version_sel.get_u8() - >= SpdmVersion::SpdmVersion12.get_u8() - { + if self.common.negotiate_info.spdm_version_sel >= SpdmVersion::SpdmVersion12 { let message_a = self.common.runtime_info.message_a.clone(); message_l1l2 .append_message(message_a.as_ref()) @@ -352,9 +347,7 @@ impl ResponderContext { debug!("message_l1l2_hash - {:02x?}", message_l1l2_hash.as_ref()); - if self.common.negotiate_info.spdm_version_sel.get_u8() - >= SpdmVersion::SpdmVersion12.get_u8() - { + if self.common.negotiate_info.spdm_version_sel >= SpdmVersion::SpdmVersion12 { message_l1l2.reset_message(); message_l1l2 .append_message(&SPDM_VERSION_1_2_SIGNING_PREFIX_CONTEXT) diff --git a/spdmlib/src/responder/psk_exchange_rsp.rs b/spdmlib/src/responder/psk_exchange_rsp.rs index a84ecb08..315d6cec 100644 --- a/spdmlib/src/responder/psk_exchange_rsp.rs +++ b/spdmlib/src/responder/psk_exchange_rsp.rs @@ -52,7 +52,7 @@ impl ResponderContext { self.write_spdm_error(SpdmErrorCode::SpdmErrorVersionMismatch, 0, writer); return Err(SPDM_STATUS_INVALID_MSG_FIELD); } - if message_header.version.get_u8() < SpdmVersion::SpdmVersion11.get_u8() { + if message_header.version < SpdmVersion::SpdmVersion11 { self.write_spdm_error(SpdmErrorCode::SpdmErrorUnsupportedRequest, 0, writer); return Err(SPDM_STATUS_INVALID_MSG_FIELD); } @@ -120,8 +120,8 @@ impl ResponderContext { .get_secure_spdm_version() == local_version { - if self.common.negotiate_info.spdm_version_sel.get_u8() - < SpdmVersion::SpdmVersion12.get_u8() + if self.common.negotiate_info.spdm_version_sel + < SpdmVersion::SpdmVersion12 { return_opaque.data_size = crate::common::opaque::RSP_DMTF_OPAQUE_DATA_VERSION_SELECTION_DSP0277