From 10affc035990f5406fa73a058a3a621105ea96e3 Mon Sep 17 00:00:00 2001 From: OuyangHang33 Date: Fri, 29 Mar 2024 09:26:28 +0800 Subject: [PATCH] Improve roll back keys method when met key_update fail Signed-off-by: OuyangHang33 --- spdmlib/src/common/session.rs | 98 +++++++++---------------- spdmlib/src/requester/key_update_req.rs | 55 ++++++++------ spdmlib/src/responder/key_update_rsp.rs | 22 ++++-- 3 files changed, 87 insertions(+), 88 deletions(-) diff --git a/spdmlib/src/common/session.rs b/spdmlib/src/common/session.rs index a1acadc..66890a1 100644 --- a/spdmlib/src/common/session.rs +++ b/spdmlib/src/common/session.rs @@ -572,6 +572,17 @@ impl SpdmSession { Ok(()) } + pub fn back_up_data_secret(&mut self) { + self.application_secret_backup.request_data_secret = + self.application_secret.request_data_secret.clone(); + self.application_secret_backup.request_direction = + self.application_secret.request_direction.clone(); + self.application_secret_backup.response_data_secret = + self.application_secret.response_data_secret.clone(); + self.application_secret_backup.response_direction = + self.application_secret.response_direction.clone(); + } + pub fn create_data_secret_update( &mut self, spdm_version: SpdmVersion, @@ -586,11 +597,6 @@ impl SpdmSession { let aead_algo = self.crypto_param.aead_algo; if update_requester { - self.application_secret_backup.request_data_secret = - self.application_secret.request_data_secret.clone(); - self.application_secret_backup.request_direction = - self.application_secret.request_direction.clone(); - self.application_secret.request_data_secret = if let Some(us) = self.key_schedule.derive_update_secret( spdm_version, @@ -635,11 +641,6 @@ impl SpdmSession { } if update_responder { - self.application_secret_backup.response_data_secret = - self.application_secret.response_data_secret.clone(); - self.application_secret_backup.response_direction = - self.application_secret.response_direction.clone(); - self.application_secret.response_data_secret = if let Some(us) = self.key_schedule.derive_update_secret( spdm_version, @@ -685,41 +686,32 @@ impl SpdmSession { Ok(()) } - pub fn activate_data_secret_update( - &mut self, - _spdm_version: SpdmVersion, - update_requester: bool, - update_responder: bool, - use_new_key: bool, - ) -> SpdmResult { - if !use_new_key { - if update_requester { - self.application_secret.request_data_secret = - self.application_secret_backup.request_data_secret.clone(); - self.application_secret.request_direction = - self.application_secret_backup.request_direction.clone(); - } - if update_responder { - self.application_secret.response_data_secret = - self.application_secret_backup.response_data_secret.clone(); - self.application_secret.response_direction = - self.application_secret_backup.response_direction.clone(); - } - } else { - if update_requester { - self.application_secret_backup.request_data_secret = - SpdmDirectionDataSecretStruct::default(); - self.application_secret_backup.request_direction = - SpdmSessionSecretParam::default(); - } - if update_responder { - self.application_secret_backup.response_data_secret = - SpdmDirectionDataSecretStruct::default(); - self.application_secret_backup.response_direction = - SpdmSessionSecretParam::default(); - } + pub fn roll_back_data_secret(&mut self, update_requester: bool, update_responder: bool) { + if update_requester { + self.application_secret.request_data_secret = + self.application_secret_backup.request_data_secret.clone(); + self.application_secret.request_direction = + self.application_secret_backup.request_direction.clone(); + } + if update_responder { + self.application_secret.response_data_secret = + self.application_secret_backup.response_data_secret.clone(); + self.application_secret.response_direction = + self.application_secret_backup.response_direction.clone(); + } + } + + pub fn zero_data_secret_backup(&mut self, update_requester: bool, update_responder: bool) { + if update_requester { + self.application_secret_backup.request_data_secret = + SpdmDirectionDataSecretStruct::default(); + self.application_secret_backup.request_direction = SpdmSessionSecretParam::default(); + } + if update_responder { + self.application_secret_backup.response_data_secret = + SpdmDirectionDataSecretStruct::default(); + self.application_secret_backup.response_direction = SpdmSessionSecretParam::default(); } - Ok(()) } pub fn generate_hmac_with_response_finished_key( @@ -1370,24 +1362,6 @@ mod tests_session { } } - #[test] - fn test_case0_activate_data_secret_update() { - let mut session = SpdmSession::default(); - let status = session - .activate_data_secret_update(SpdmVersion::SpdmVersion12, true, true, false) - .is_ok(); - assert!(status); - - let status = session - .activate_data_secret_update(SpdmVersion::SpdmVersion12, true, false, false) - .is_ok(); - assert!(status); - - let status = session - .activate_data_secret_update(SpdmVersion::SpdmVersion12, false, false, false) - .is_ok(); - assert!(status); - } #[test] fn test_case0_decode_msg() { let mut session = SpdmSession::default(); diff --git a/spdmlib/src/requester/key_update_req.rs b/spdmlib/src/requester/key_update_req.rs index 7aec510..3288c30 100644 --- a/spdmlib/src/requester/key_update_req.rs +++ b/spdmlib/src/requester/key_update_req.rs @@ -45,12 +45,24 @@ impl RequesterContext { .receive_message(Some(session_id), &mut receive_buffer, false) .await?; - self.handle_spdm_key_update_op_response( + let res = self.handle_spdm_key_update_op_response( session_id, update_requester, update_responder, &receive_buffer[..used], - ) + ); + if res.is_ok() { + res + } else { + let session = if let Some(s) = self.common.get_session_via_id(session_id) { + s + } else { + return Err(SPDM_STATUS_INVALID_PARAMETER); + }; + // Roll back all keys when meet key update operation response Error + session.roll_back_data_secret(true, true); + res + } } pub fn encode_spdm_key_update_op( @@ -90,7 +102,6 @@ impl RequesterContext { SpdmRequestResponseCode::SpdmResponseKeyUpdateAck => { let key_update_rsp = SpdmKeyUpdateResponsePayload::spdm_read(&mut self.common, &mut reader); - let spdm_version_sel = self.common.negotiate_info.spdm_version_sel; let session = if let Some(s) = self.common.get_session_via_id(session_id) { s } else { @@ -98,21 +109,10 @@ impl RequesterContext { }; if let Some(key_update_rsp) = key_update_rsp { debug!("!!! key_update rsp : {:02x?}\n", key_update_rsp); - session.activate_data_secret_update( - spdm_version_sel, - update_requester, - update_responder, - true, - )?; Ok(()) } else { error!("!!! key_update : fail !!!\n"); - session.activate_data_secret_update( - spdm_version_sel, - update_requester, - update_responder, - false, - )?; + session.roll_back_data_secret(update_requester, update_responder); Err(SPDM_STATUS_INVALID_MSG_FIELD) } } @@ -141,13 +141,26 @@ impl RequesterContext { { return Err(SPDM_STATUS_INVALID_MSG_FIELD); } + // back up original keys + self.common + .get_session_via_id(session_id) + .unwrap() + .back_up_data_secret(); self.send_receive_spdm_key_update_op(session_id, key_update_operation, 1) .await?; - self.send_receive_spdm_key_update_op( - session_id, - SpdmKeyUpdateOperation::SpdmVerifyNewKey, - 2, - ) - .await + + let res = self + .send_receive_spdm_key_update_op( + session_id, + SpdmKeyUpdateOperation::SpdmVerifyNewKey, + 2, + ) + .await; + + self.common + .get_session_via_id(session_id) + .unwrap() + .zero_data_secret_backup(true, true); + res } } diff --git a/spdmlib/src/responder/key_update_rsp.rs b/spdmlib/src/responder/key_update_rsp.rs index 63d43a5..36411e3 100644 --- a/spdmlib/src/responder/key_update_rsp.rs +++ b/spdmlib/src/responder/key_update_rsp.rs @@ -5,6 +5,7 @@ use crate::common::SpdmCodec; use crate::error::SpdmResult; use crate::error::SPDM_STATUS_INVALID_MSG_FIELD; +use crate::error::SPDM_STATUS_INVALID_PARAMETER; use crate::error::SPDM_STATUS_INVALID_STATE_LOCAL; use crate::message::*; use crate::responder::*; @@ -54,6 +55,10 @@ impl ResponderContext { debug!("!!! key_update req : {:02x?}\n", key_update_req); } else { error!("!!! key_update req : fail !!!\n"); + self.common + .get_session_via_id(session_id) + .unwrap() + .roll_back_data_secret(true, true); self.write_spdm_error(SpdmErrorCode::SpdmErrorInvalidRequest, 0, writer); return ( Err(SPDM_STATUS_INVALID_MSG_FIELD), @@ -63,20 +68,20 @@ impl ResponderContext { let key_update_req = key_update_req.unwrap(); let spdm_version_sel = self.common.negotiate_info.spdm_version_sel; - let session = self.common.get_session_via_id(session_id).unwrap(); + let session: &mut crate::common::session::SpdmSession = + self.common.get_session_via_id(session_id).unwrap(); match key_update_req.key_update_operation { SpdmKeyUpdateOperation::SpdmUpdateSingleKey => { let _ = session.create_data_secret_update(spdm_version_sel, true, false); } SpdmKeyUpdateOperation::SpdmUpdateAllKeys => { let _ = session.create_data_secret_update(spdm_version_sel, true, true); - let _ = session.activate_data_secret_update(spdm_version_sel, true, true, true); - } - SpdmKeyUpdateOperation::SpdmVerifyNewKey => { - let _ = session.activate_data_secret_update(spdm_version_sel, true, false, true); } + SpdmKeyUpdateOperation::SpdmVerifyNewKey => {} _ => { error!("!!! key_update req : fail !!!\n"); + // roll back respnder and requester Keys + session.roll_back_data_secret(true, true); self.write_spdm_error(SpdmErrorCode::SpdmErrorInvalidRequest, 0, writer); return ( Err(SPDM_STATUS_INVALID_MSG_FIELD), @@ -99,6 +104,13 @@ impl ResponderContext { }; let res = response.spdm_encode(&mut self.common, writer); if res.is_err() { + // when sending KEY_UPDATE_ACK response fail, roll back respnder and requester Keys + let session = if let Some(s) = self.common.get_session_via_id(session_id) { + s + } else { + return (Err(SPDM_STATUS_INVALID_PARAMETER), None); + }; + session.roll_back_data_secret(true, true); self.write_spdm_error(SpdmErrorCode::SpdmErrorUnspecified, 0, writer); return ( Err(SPDM_STATUS_INVALID_STATE_LOCAL),