diff --git a/libsignal-service/src/account_manager.rs b/libsignal-service/src/account_manager.rs index 127054712..9f076ca67 100644 --- a/libsignal-service/src/account_manager.rs +++ b/libsignal-service/src/account_manager.rs @@ -14,13 +14,11 @@ use std::collections::HashMap; use std::convert::TryFrom; use std::time::SystemTime; -use libsignal_protocol::keys::PublicKey; -use libsignal_protocol::{Context, StoreContext}; +use libsignal_protocol::PublicKey; use zkgroup::profiles::ProfileKey; pub struct AccountManager { - context: Context, service: Service, profile_key: Option<[u8; 32]>, } @@ -51,13 +49,8 @@ const PRE_KEY_MINIMUM: u32 = 10; const PRE_KEY_BATCH_SIZE: u32 = 100; impl AccountManager { - pub fn new( - context: Context, - service: Service, - profile_key: Option<[u8; 32]>, - ) -> Self { + pub fn new(service: Service, profile_key: Option<[u8; 32]>) -> Self { Self { - context, service, profile_key, } @@ -96,13 +89,11 @@ impl AccountManager { } let pre_keys = libsignal_protocol::generate_pre_keys( - &self.context, pre_keys_offset_id, PRE_KEY_BATCH_SIZE, )?; let identity_key_pair = store_context.identity_key_pair()?; let signed_pre_key = libsignal_protocol::generate_signed_pre_key( - &self.context, &identity_key_pair, next_signed_pre_key_id, SystemTime::now(), @@ -207,7 +198,7 @@ impl AccountManager { query.get("pub_key").ok_or(LinkError::InvalidPublicKey)?; let pub_key = base64::decode(&**pub_key) .map_err(|_e| LinkError::InvalidPublicKey)?; - let pub_key = PublicKey::decode_point(&self.context, &pub_key) + let pub_key = PublicKey::deserialize(&pub_key) .map_err(|_e| LinkError::InvalidPublicKey)?; let identity_key_pair = store_context.identity_key_pair()?; diff --git a/libsignal-service/src/cipher.rs b/libsignal-service/src/cipher.rs index dd4f87f20..0c81c6611 100644 --- a/libsignal-service/src/cipher.rs +++ b/libsignal-service/src/cipher.rs @@ -99,11 +99,11 @@ impl ServiceCipher { let sender = get_preferred_protocol_address( &self.store_context, envelope.source_address(), - envelope.source_device() as i32, + envelope.source_device(), )?; let metadata = Metadata { sender: envelope.source_address(), - sender_device: envelope.source_device() as i32, + sender_device: envelope.source_device(), timestamp: envelope.timestamp(), needs_receipt: false, }; @@ -130,11 +130,11 @@ impl ServiceCipher { let sender = get_preferred_protocol_address( &self.store_context, envelope.source_address(), - envelope.source_device() as i32, + envelope.source_device(), )?; let metadata = Metadata { sender: envelope.source_address(), - sender_device: envelope.source_device() as i32, + sender_device: envelope.source_device(), timestamp: envelope.timestamp(), needs_receipt: false, }; @@ -293,8 +293,8 @@ fn strip_padding( pub fn get_preferred_protocol_address( store_context: &StoreContext, address: ServiceAddress, - device_id: i32, -) -> Result { + device_id: u32, +) -> Result { if let Some(ref uuid) = address.uuid { let address = ProtocolAddress::new(uuid.to_string(), device_id as i32); if store_context.contains_session(&address)? { diff --git a/libsignal-service/src/configuration.rs b/libsignal-service/src/configuration.rs index 3c97a4bda..2a8ec4b6f 100644 --- a/libsignal-service/src/configuration.rs +++ b/libsignal-service/src/configuration.rs @@ -29,7 +29,7 @@ pub struct ServiceCredentials { pub phonenumber: phonenumber::PhoneNumber, pub password: Option, pub signaling_key: Option, - pub device_id: Option, + pub device_id: Option, } impl ServiceCredentials { diff --git a/libsignal-service/src/content.rs b/libsignal-service/src/content.rs index 850105563..87b083ea2 100644 --- a/libsignal-service/src/content.rs +++ b/libsignal-service/src/content.rs @@ -12,7 +12,7 @@ pub use crate::{ #[derive(Clone, Debug)] pub struct Metadata { pub sender: crate::ServiceAddress, - pub sender_device: i32, + pub sender_device: u32, pub timestamp: u64, pub needs_receipt: bool, } diff --git a/libsignal-service/src/groups_v2/utils.rs b/libsignal-service/src/groups_v2/utils.rs index bf874084d..98cc5889f 100644 --- a/libsignal-service/src/groups_v2/utils.rs +++ b/libsignal-service/src/groups_v2/utils.rs @@ -1,4 +1,4 @@ -use libsignal_protocol::{Context, Error}; +use libsignal_protocol::{error::SignalProtocolError, Context}; use zkgroup::groups::GroupMasterKey; use zkgroup::GROUP_MASTER_KEY_LEN; @@ -8,15 +8,11 @@ use zkgroup::GROUP_MASTER_KEY_LEN; pub fn derive_v2_migration_master_key( ctx: &Context, group_id: &[u8], -) -> Result { +) -> Result { assert_eq!(group_id.len(), 16, "Group ID must be exactly 16 bytes"); - let hkdf = libsignal_protocol::create_hkdf(ctx, 3)?; - let bytes = hkdf.derive_secrets( - GROUP_MASTER_KEY_LEN, - group_id, - &[], - b"GV2 Migration", - )?; + let hkdf = libsignal_protocol::HKDF::new(3)?; + let bytes = + hkdf.derive_secrets(group_id, b"GV2 Migration", GROUP_MASTER_KEY_LEN)?; let mut bytes_stack = [0u8; GROUP_MASTER_KEY_LEN]; bytes_stack.copy_from_slice(&bytes); Ok(GroupMasterKey::new(bytes_stack)) diff --git a/libsignal-service/src/provisioning.rs b/libsignal-service/src/provisioning.rs new file mode 100644 index 000000000..3eb42b147 --- /dev/null +++ b/libsignal-service/src/provisioning.rs @@ -0,0 +1,440 @@ +use aes::Aes256; +use block_modes::{block_padding::Pkcs7, BlockMode, Cbc}; +use bytes::{Bytes, BytesMut}; +use futures::{ + channel::mpsc::{self, Sender}, + prelude::*, + stream::FuturesUnordered, +}; +use hmac::{Hmac, Mac, NewMac}; +use pin_project::pin_project; +use prost::Message; +use rand::Rng; +use sha2::Sha256; +use url::Url; + +use libsignal_protocol::{ + keys::{KeyPair, PublicKey}, + Context, +}; + +pub use crate::proto::{ + ProvisionEnvelope, ProvisionMessage, ProvisioningVersion, +}; + +use crate::{ + envelope::{CIPHER_KEY_SIZE, IV_LENGTH, IV_OFFSET}, + messagepipe::{WebSocketService, WebSocketStreamItem}, + proto::{ + web_socket_message, ProvisioningUuid, WebSocketMessage, + WebSocketRequestMessage, WebSocketResponseMessage, + }, + push_service::ServiceError, +}; + +const VERSION: u8 = 1; + +#[derive(Debug)] +enum CipherMode { + Decrypt(KeyPair), + Encrypt(PublicKey), +} + +impl CipherMode { + fn public(&self) -> PublicKey { + match self { + CipherMode::Decrypt(pair) => pair.public(), + CipherMode::Encrypt(pub_key) => pub_key.clone(), + } + } +} + +#[derive(Debug)] +pub struct ProvisioningCipher { + ctx: Context, + key_material: CipherMode, +} + +#[derive(thiserror::Error, Debug)] +pub enum ProvisioningError { + #[error("Invalid provisioning data: {reason}")] + InvalidData { reason: String }, + #[error("Protobuf decoding error: {0}")] + DecodeError(#[from] prost::DecodeError), + #[error("Websocket error: {reason}")] + WsError { reason: String }, + #[error("Websocket closing: {reason}")] + WsClosing { reason: String }, + #[error("Service error: {0}")] + ServiceError(#[from] ServiceError), + #[error("libsignal-protocol error: {0}")] + ProtocolError(#[from] libsignal_protocol::error::SignalProtocolError), + #[error("ProvisioningCipher in encrypt-only mode")] + EncryptOnlyProvisioningCipher, +} + +impl ProvisioningCipher { + pub fn new(ctx: Context) -> Result { + let key_pair = libsignal_protocol::generate_key_pair(&ctx)?; + Ok(Self { + ctx, + key_material: CipherMode::Decrypt(key_pair), + }) + } + + pub fn from_public(ctx: Context, key: PublicKey) -> Self { + Self { + ctx, + key_material: CipherMode::Encrypt(key), + } + } + + pub fn from_key_pair(ctx: Context, key_pair: KeyPair) -> Self { + Self { + ctx, + key_material: CipherMode::Decrypt(key_pair), + } + } + + pub fn public_key(&self) -> PublicKey { + self.key_material.public() + } + + pub fn encrypt( + &self, + msg: ProvisionMessage, + ) -> Result { + let msg = { + let mut encoded = Vec::with_capacity(msg.encoded_len()); + msg.encode(&mut encoded).expect("infallible encoding"); + encoded + }; + + let mut rng = rand::thread_rng(); + let our_key_pair = libsignal_protocol::generate_key_pair(&self.ctx)?; + let agreement = self + .public_key() + .calculate_agreement(&our_key_pair.private())?; + let hkdf = libsignal_protocol::create_hkdf(&self.ctx, 3)?; + + let shared_secrets = hkdf.derive_secrets( + 64, + &agreement, + &[], + b"TextSecure Provisioning Message", + )?; + + let aes_key = &shared_secrets[0..32]; + let mac_key = &shared_secrets[32..]; + let iv: [u8; IV_LENGTH] = rng.gen(); + + let cipher = Cbc::::new_var(&aes_key, &iv) + .expect("initalization of CBC/AES/PKCS7"); + let ciphertext = cipher.encrypt_vec(&msg); + let mut mac = Hmac::::new_varkey(&mac_key) + .expect("HMAC can take any size key"); + mac.update(&[VERSION]); + mac.update(&iv); + mac.update(&ciphertext); + let mac = mac.finalize().into_bytes(); + + let body: Vec = std::iter::once(VERSION) + .chain(iv.iter().cloned()) + .chain(ciphertext) + .chain(mac) + .collect(); + + Ok(ProvisionEnvelope { + public_key: Some( + our_key_pair.public().to_bytes()?.as_slice().to_vec(), + ), + body: Some(body), + }) + } + + pub fn decrypt( + &self, + provision_envelope: ProvisionEnvelope, + ) -> Result { + let key_pair = match self.key_material { + CipherMode::Decrypt(ref key_pair) => key_pair, + CipherMode::Encrypt(_) => { + return Err(ProvisioningError::EncryptOnlyProvisioningCipher); + } + }; + let master_ephemeral = PublicKey::decode_point( + &self.ctx, + &provision_envelope.public_key.expect("no public key"), + )?; + let body = provision_envelope + .body + .expect("no body in ProvisionMessage"); + if body[0] != VERSION { + return Err(ProvisioningError::InvalidData { + reason: "Bad version number".into(), + }); + } + + let iv = &body[IV_OFFSET..(IV_LENGTH + IV_OFFSET)]; + let mac = &body[(body.len() - 32)..]; + let cipher_text = &body[16 + 1..(body.len() - CIPHER_KEY_SIZE)]; + let iv_and_cipher_text = &body[0..(body.len() - CIPHER_KEY_SIZE)]; + debug_assert_eq!(iv.len(), IV_LENGTH); + debug_assert_eq!(mac.len(), 32); + + let agreement = + master_ephemeral.calculate_agreement(&key_pair.private())?; + let hkdf = libsignal_protocol::create_hkdf(&self.ctx, 3)?; + + let shared_secrets = hkdf.derive_secrets( + 64, + &agreement, + &[], + b"TextSecure Provisioning Message", + )?; + + let parts1 = &shared_secrets[0..32]; + let parts2 = &shared_secrets[32..]; + + let mut verifier = Hmac::::new_varkey(&parts2) + .expect("HMAC can take any size key"); + verifier.update(&iv_and_cipher_text); + let our_mac = verifier.finalize().into_bytes(); + debug_assert_eq!(our_mac.len(), mac.len()); + if &our_mac[..32] != mac { + return Err(ProvisioningError::InvalidData { + reason: "wrong MAC".into(), + }); + } + + // libsignal-service-java uses Pkcs5, + // but that should not matter. + // https://crypto.stackexchange.com/questions/9043/what-is-the-difference-between-pkcs5-padding-and-pkcs7-padding + let cipher = Cbc::::new_var(&parts1, &iv) + .expect("initalization of CBC/AES/PKCS7"); + let input = cipher.decrypt_vec(cipher_text).map_err(|e| { + ProvisioningError::InvalidData { + reason: format!("CBC/Padding error: {:?}", e), + } + })?; + + Ok(prost::Message::decode(Bytes::from(input))?) + } +} + +#[pin_project] +pub struct ProvisioningPipe { + ws: WS, + #[pin] + stream: WS::Stream, + provisioning_cipher: ProvisioningCipher, +} + +#[derive(Debug)] +pub enum ProvisioningStep { + Url(Url), + Message(ProvisionMessage), +} + +impl ProvisioningPipe { + pub fn from_socket( + ws: WS, + stream: WS::Stream, + ctx: &Context, + ) -> Result { + Ok(ProvisioningPipe { + ws, + stream, + provisioning_cipher: ProvisioningCipher::new(ctx.clone())?, + }) + } + + async fn send_ok_response( + &mut self, + id: Option, + ) -> Result<(), ProvisioningError> { + self.send_response(WebSocketResponseMessage { + id, + status: Some(200), + message: Some("OK".into()), + body: None, + headers: vec![], + }) + .await + } + + async fn send_response( + &mut self, + r: WebSocketResponseMessage, + ) -> Result<(), ProvisioningError> { + let msg = WebSocketMessage { + r#type: Some(web_socket_message::Type::Response.into()), + response: Some(r), + ..Default::default() + }; + let mut buffer = BytesMut::with_capacity(msg.encoded_len()); + msg.encode(&mut buffer).unwrap(); + Ok(self.ws.send_message(buffer.into()).await?) + } + + /// Worker task that + async fn run( + mut self, + mut sink: Sender>, + ) -> Result<(), mpsc::SendError> { + use futures::future::LocalBoxFuture; + + // This is a runtime-agnostic, poor man's `::spawn(Future)`. + let mut background_work = FuturesUnordered::>::new(); + // a pending task is added, as to never end the background worker until + // it's dropped. + background_work.push(futures::future::pending().boxed_local()); + + loop { + futures::select! { + // WebsocketConnection::onMessage(ByteString) + frame = self.stream.next() => match frame { + Some(WebSocketStreamItem::Message(frame)) => { + let env = self.process_frame(frame).await.transpose(); + if let Some(env) = env { + sink.send(env).await?; + } + }, + // TODO: implement keep-alive? + Some(WebSocketStreamItem::KeepAliveRequest) => continue, + None => break, + }, + _ = background_work.next() => { + // no op + }, + complete => { + log::info!("select! complete"); + } + } + } + + Ok(()) + } + + async fn process_frame( + &mut self, + frame: Bytes, + ) -> Result, ProvisioningError> { + let msg = WebSocketMessage::decode(frame)?; + use web_socket_message::Type; + match (msg.r#type(), msg.request, msg.response) { + (Type::Request, Some(request), _) => { + match request { + // step 1: we get a ProvisioningUUID that we need to build a + // registration link + WebSocketRequestMessage { + id, + verb, + path, + body, + .. + } if verb == Some("PUT".into()) + && path == Some("/v1/address".into()) => + { + let uuid: ProvisioningUuid = + prost::Message::decode(Bytes::from(body.unwrap()))?; + let mut provisioning_url = Url::parse("tsdevice://") + .map_err(|e| ProvisioningError::WsError { + reason: e.to_string(), + })?; + provisioning_url + .query_pairs_mut() + .append_pair("uuid", &uuid.uuid.unwrap()) + .append_pair( + "pub_key", + &format!( + "{}", + self.provisioning_cipher.public_key() + ), + ); + + // acknowledge + self.send_ok_response(id).await?; + + Ok(Some(ProvisioningStep::Url(provisioning_url))) + } + // step 2: once the QR code is scanned by the (already + // validated) main device + // we get a ProvisionMessage, that contains a bunch of + // useful things + WebSocketRequestMessage { + id, + verb, + path, + body, + .. + } if verb == Some("PUT".into()) + && path == Some("/v1/message".into()) => + { + let provision_envelope: ProvisionEnvelope = + prost::Message::decode(Bytes::from(body.unwrap()))?; + let provision_message = self + .provisioning_cipher + .decrypt(provision_envelope)?; + + // acknowledge + self.send_ok_response(id).await?; + + Ok(Some(ProvisioningStep::Message(provision_message))) + } + _ => Err(ProvisioningError::WsError { + reason: "Incorrect request".into(), + }), + } + } + _ => Err(ProvisioningError::WsError { + reason: "Incorrect request".into(), + }), + } + } + + pub fn stream( + self, + ) -> impl Stream> { + let (sink, stream) = mpsc::channel(1); + + let stream = stream.map(Some); + let runner = self.run(sink).map(|_| { + log::info!("Sink closed, provisioning is done!"); + None + }); + + let combined = futures::stream::select(stream, runner.into_stream()); + combined.filter_map(|x| async { x }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn encrypt_provisioning_roundtrip() { + let ctx = Context::default(); + let cipher = ProvisioningCipher::new(ctx.clone()).unwrap(); + let encrypt_cipher = + ProvisioningCipher::from_public(ctx.clone(), cipher.public_key()); + + assert_eq!( + cipher.public_key(), + encrypt_cipher.public_key(), + "copy public key" + ); + + let msg = ProvisionMessage::default(); + let encrypted = encrypt_cipher.encrypt(msg.clone()).unwrap(); + + assert!(matches!( + encrypt_cipher.decrypt(encrypted.clone()), + Err(ProvisioningError::EncryptOnlyProvisioningCipher) + )); + + let decrypted = cipher.decrypt(encrypted).expect("decryptability"); + assert_eq!(msg, decrypted); + } +} diff --git a/libsignal-service/src/push_service.rs b/libsignal-service/src/push_service.rs index d5801e49f..89822793b 100644 --- a/libsignal-service/src/push_service.rs +++ b/libsignal-service/src/push_service.rs @@ -61,7 +61,7 @@ pub const STICKER_PATH: &str = "stickers/%s/full/%d"; **/ pub const KEEPALIVE_TIMEOUT_SECONDS: Duration = Duration::from_secs(55); -pub const DEFAULT_DEVICE_ID: i32 = 1; +pub const DEFAULT_DEVICE_ID: u32 = 1; #[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] @@ -244,7 +244,7 @@ pub enum ServiceError { MacError, #[error("Protocol error: {0}")] - SignalProtocolError(#[from] libsignal_protocol::Error), + SignalProtocolError(#[from] libsignal_protocol::error::SignalProtocolError), #[error("{0:?}")] MismatchedDevicesException(MismatchedDevices), @@ -534,7 +534,7 @@ pub trait PushService { &mut self, context: &Context, destination: &ServiceAddress, - device_id: i32, + device_id: u32, ) -> Result, ServiceError> { let path = match (device_id, destination.relay.as_ref()) { (1, None) => format!("/v2/keys/{}/*", destination.identifier()), diff --git a/libsignal-service/src/sealed_session_cipher.rs b/libsignal-service/src/sealed_session_cipher.rs index 628358d4c..06a5be4d1 100644 --- a/libsignal-service/src/sealed_session_cipher.rs +++ b/libsignal-service/src/sealed_session_cipher.rs @@ -102,7 +102,7 @@ pub struct UnidentifiedSenderMessageContent { pub struct SenderCertificate { signer: ServerCertificate, key: PublicKey, - sender_device_id: i32, + sender_device_id: u32, sender_uuid: Option, sender_e164: Option, expiration: u64, @@ -140,7 +140,7 @@ pub struct CertificateValidator { pub(crate) struct DecryptionResult { pub sender_uuid: Option, pub sender_e164: Option, - pub device_id: i32, + pub device_id: u32, pub padded_message: Vec, pub version: u32, } @@ -600,7 +600,7 @@ impl SenderCertificate { )?, sender_e164, sender_uuid, - sender_device_id: sender_device_id as i32, + sender_device_id, expiration: expires, certificate, signature, diff --git a/libsignal-service/src/sender.rs b/libsignal-service/src/sender.rs index cebdb7105..b3c105013 100644 --- a/libsignal-service/src/sender.rs +++ b/libsignal-service/src/sender.rs @@ -69,7 +69,7 @@ pub struct AttachmentSpec { pub struct MessageSender { service: Service, cipher: ServiceCipher, - device_id: i32, + device_id: u32, } #[derive(thiserror::Error, Debug)] @@ -86,7 +86,7 @@ pub enum MessageSenderError { #[error("{0}")] ServiceError(#[from] ServiceError), #[error("protocol error: {0}")] - ProtocolError(#[from] libsignal_protocol::Error), + ProtocolError(#[from] libsignal_protocol::error::SignalProtocolError), #[error("Failed to upload attachment {0}")] AttachmentUploadError(#[from] AttachmentUploadError), @@ -119,7 +119,7 @@ where pub fn new( service: Service, cipher: ServiceCipher, - device_id: i32, + device_id: u32, ) -> Self { MessageSender { service, @@ -680,7 +680,7 @@ where &mut self, recipient: &ServiceAddress, unidentified_access: Option<&UnidentifiedAccess>, - device_id: i32, + device_id: u32, content: &[u8], ) -> Result { let recipient_address = get_preferred_protocol_address( @@ -702,7 +702,7 @@ where .await?; for pre_key_bundle in pre_keys { if recipient.matches(&self.cipher.local_address) - && self.device_id == pre_key_bundle.device_id() + && self.device_id == pre_key_bundle.device_id()? { trace!("not establishing a session with myself!"); continue; @@ -711,7 +711,7 @@ where let pre_key_address = get_preferred_protocol_address( &self.cipher.store_context, recipient.clone(), - pre_key_bundle.device_id(), + pre_key_bundle.device_id()?, )?; let session_builder = SessionBuilder::new( &self.cipher.context, diff --git a/libsignal-service/src/utils.rs b/libsignal-service/src/utils.rs index 32f39c77b..d84a28898 100644 --- a/libsignal-service/src/utils.rs +++ b/libsignal-service/src/utils.rs @@ -57,7 +57,7 @@ pub mod serde_optional_base64 { } pub mod serde_public_key { - use libsignal_protocol::keys::PublicKey; + use libsignal_protocol::PublicKey; use serde::Serializer; pub fn serialize( @@ -67,8 +67,7 @@ pub mod serde_public_key { where S: Serializer, { - use serde::ser::Error; - serializer - .serialize_str(&public_key.to_base64().map_err(S::Error::custom)?) + let public_key = public_key.serialize(); + serializer.serialize_str(&base64::encode(&public_key)) } }