Skip to content

Commit

Permalink
Fix sending SyncMessage to self (#263)
Browse files Browse the repository at this point in the history
- We need to make sure any message sent to self is _never_ send unidentified
- Make sure we add padding to all `SyncMessage`
- Remove some redundant session loading before creating encrypted messages
  • Loading branch information
gferon authored Nov 19, 2023
1 parent afb5114 commit 6fc62c8
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 78 deletions.
16 changes: 16 additions & 0 deletions libsignal-service/src/proto.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
#![allow(clippy::all)]

use rand::{Rng, RngCore};
include!(concat!(env!("OUT_DIR"), "/signalservice.rs"));
include!(concat!(env!("OUT_DIR"), "/signal.rs"));

Expand Down Expand Up @@ -65,3 +67,17 @@ impl WebSocketResponseMessage {
}
}
}

impl SyncMessage {
pub fn with_padding() -> Self {
let mut rng = rand::thread_rng();
let random_size = rng.gen_range(1..=512);
let mut padding: Vec<u8> = vec![0; random_size];
rng.fill_bytes(&mut padding);

Self {
padding: Some(padding),
..Default::default()
}
}
}
136 changes: 62 additions & 74 deletions libsignal-service/src/sender.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
use std::time::SystemTime;
use std::{collections::HashSet, time::SystemTime};

use chrono::prelude::*;
use libsignal_protocol::{
process_prekey_bundle, DeviceId, ProtocolStore, SenderCertificate,
SenderKeyStore, SignalProtocolError,
};
use log::{info, trace};
use log::{debug, info, trace};
use rand::{CryptoRng, Rng};
use uuid::Uuid;

Expand Down Expand Up @@ -77,7 +77,7 @@ pub struct AttachmentSpec {
/// Equivalent of Java's `SignalServiceMessageSender`.
#[derive(Clone)]
pub struct MessageSender<Service, S, R> {
ws: SignalWebSocket,
identified_ws: SignalWebSocket,
unidentified_ws: SignalWebSocket,
service: Service,
cipher: ServiceCipher<S, R>,
Expand Down Expand Up @@ -126,7 +126,7 @@ where
{
#[allow(clippy::too_many_arguments)]
pub fn new(
ws: SignalWebSocket,
identified_ws: SignalWebSocket,
unidentified_ws: SignalWebSocket,
service: Service,
cipher: ServiceCipher<S, R>,
Expand All @@ -137,7 +137,7 @@ where
) -> Self {
MessageSender {
service,
ws,
identified_ws,
unidentified_ws,
cipher,
csprng,
Expand Down Expand Up @@ -192,7 +192,10 @@ where

// Request upload attributes
log::trace!("Requesting upload attributes");
let attrs = self.ws.get_attachment_v2_upload_attributes().await?;
let attrs = self
.identified_ws
.get_attachment_v2_upload_attributes()
.await?;

log::trace!("Uploading attachment");
let (id, digest) = self
Expand Down Expand Up @@ -287,29 +290,35 @@ where
pub async fn send_message(
&mut self,
recipient: &ServiceAddress,
unidentified_access: Option<UnidentifiedAccess>,
mut unidentified_access: Option<UnidentifiedAccess>,
message: impl Into<ContentBody>,
timestamp: u64,
online: bool,
) -> SendMessageResult {
let content_body = message.into();

use crate::proto::data_message::Flags;

let end_session = match &content_body {
ContentBody::DataMessage(message) => {
unidentified_access.take(); // don't send end session as sealed sender
message.flags == Some(Flags::EndSession as u32)
},
_ => false,
};

// we never send sync messages or to our own account as sealed sender
if recipient == &self.local_address
|| matches!(&content_body, ContentBody::SynchronizeMessage(_))
{
unidentified_access.take();
}

// try to send the original message to all the recipient's devices
let mut results = vec![
self.try_send_message(
*recipient,
if !end_session {
unidentified_access.as_ref()
} else {
None
},
unidentified_access.as_ref(),
&content_body,
timestamp,
online,
Expand Down Expand Up @@ -433,14 +442,16 @@ where
online: bool,
) -> SendMessageResult {
use prost::Message;

let content = content_body.clone().into_proto();

let content_bytes = content.encode_to_vec();

for _ in 0..4u8 {
let messages = self
.create_encrypted_messages(
&recipient,
unidentified_access.as_ref().map(|x| &x.certificate),
unidentified_access.map(|x| &x.certificate),
&content_bytes,
)
.await?;
Expand All @@ -453,13 +464,13 @@ where
};

let send = if let Some(unidentified) = &unidentified_access {
log::trace!("Sending via unidentified");
log::debug!("sending via unidentified");
self.unidentified_ws
.send_messages_unidentified(messages, unidentified)
.await
} else {
log::trace!("Sending identified");
self.ws.send_messages(messages).await
log::debug!("sending identified");
self.identified_ws.send_messages(messages).await
};

match send {
Expand Down Expand Up @@ -584,7 +595,7 @@ where
blob: Some(ptr),
complete: Some(complete),
}),
..Default::default()
..SyncMessage::with_padding()
};

self.send_message(
Expand All @@ -607,61 +618,34 @@ where
content: &[u8],
) -> Result<Vec<OutgoingPushMessage>, MessageSenderError> {
let mut messages = vec![];
let mut devices = vec![];

let myself = *recipient == self.local_address;
if !myself || unidentified_access.is_some() {
trace!("sending message to default device");
let mut devices: HashSet<DeviceId> = self
.protocol_store
.get_sub_device_sessions(recipient)
.await?
.into_iter()
.map(DeviceId::from)
.collect();

// always send to the primary device no matter what
devices.insert(DEFAULT_DEVICE_ID.into());

// when sending an identified message, remove ourselves from the list of recipients
if unidentified_access.is_none() {
devices.remove(&self.device_id);
}

for device_id in devices {
debug!("sending message to device {}", device_id);
messages.push(
self.create_encrypted_message(
recipient,
unidentified_access,
DEFAULT_DEVICE_ID.into(),
device_id,
content,
)
.await?,
);
} else {
devices.push(DEFAULT_DEVICE_ID);
}

devices.extend(
self.protocol_store
.get_sub_device_sessions(recipient)
.await?,
);
devices.sort_unstable();

let original_device_count = devices.len();
devices.dedup();
if devices.len() != original_device_count {
log::warn!("SessionStoreExt::get_sub_device_sessions should return unique device id's, and should not include DEFAULT_DEVICE_ID.");
}

// When sending to ourselves, don't include the local device
if myself {
devices.retain(|id| DeviceId::from(*id) != self.device_id);
}

for device_id in devices {
trace!("sending message to device {}", device_id);
let ppa = get_preferred_protocol_address(
&self.protocol_store,
recipient,
device_id.into(),
)
.await?;
if self.protocol_store.load_session(&ppa).await?.is_some() {
messages.push(
self.create_encrypted_message(
recipient,
unidentified_access,
device_id.into(),
content,
)
.await?,
)
}
}

Ok(messages)
Expand All @@ -677,21 +661,23 @@ where
device_id: DeviceId,
content: &[u8],
) -> Result<OutgoingPushMessage, MessageSenderError> {
let recipient_address = get_preferred_protocol_address(
&self.protocol_store,
recipient,
device_id,
)
.await?;
log::trace!("encrypting message for {:?}", recipient_address);
let recipient_protocol_address =
recipient.to_protocol_address(device_id);

log::debug!("encrypting message for {}", recipient_protocol_address);

// establish a session with the recipient/device if necessary
// no need to establish a session with ourselves (and our own current device)
if self
.protocol_store
.load_session(&recipient_address)
.load_session(&recipient_protocol_address)
.await?
.is_none()
{
info!("establishing new session with {:?}", recipient_address);
info!(
"establishing new session with {}",
recipient_protocol_address
);
let pre_keys = match self
.service
.get_pre_keys(recipient, device_id.into())
Expand All @@ -708,6 +694,7 @@ where
},
Err(e) => Err(e)?,
};

for pre_key_bundle in pre_keys {
if recipient == &self.local_address
&& self.device_id == pre_key_bundle.device_id()?
Expand Down Expand Up @@ -737,8 +724,9 @@ where

let message = self
.cipher
.encrypt(&recipient_address, unidentified_access, content)
.encrypt(&recipient_protocol_address, unidentified_access, content)
.await?;

Ok(message)
}

Expand Down Expand Up @@ -786,7 +774,7 @@ where
unidentified_status,
..Default::default()
}),
..Default::default()
..SyncMessage::with_padding()
})
}
}
8 changes: 4 additions & 4 deletions libsignal-service/src/websocket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ impl<WS: WebSocketService> SignalWebSocketProcess<WS> {
frame: Bytes,
) -> Result<(), ServiceError> {
let msg = WebSocketMessage::decode(frame)?;
log::trace!("Decoded {:?}", msg);
log::trace!("decoded {:?}", msg);

use web_socket_message::Type;
match (msg.r#type(), msg.request, msg.response) {
Expand Down Expand Up @@ -184,7 +184,7 @@ impl<WS: WebSocketService> SignalWebSocketProcess<WS> {
.filter(|x| !self.outgoing_request_map.contains_key(x))
.unwrap_or_else(|| self.next_request_id()),
);
log::trace!("Sending request {:?}", request);
log::trace!("sending request {:?}", request);

self.outgoing_request_map.insert(request.id.unwrap(), responder);
let msg = WebSocketMessage {
Expand Down Expand Up @@ -227,7 +227,7 @@ impl<WS: WebSocketService> SignalWebSocketProcess<WS> {
self.ws.send_message(buffer.into()).await?
}
Some(WebSocketStreamItem::KeepAliveRequest) => {
log::debug!("keep alive is disabled: ignoring request");
log::trace!("keep alive is disabled: ignoring request");
}
None => {
return Err(ServiceError::WsError {
Expand All @@ -239,7 +239,7 @@ impl<WS: WebSocketService> SignalWebSocketProcess<WS> {
response = self.outgoing_responses.next() => {
match response {
Some(Ok(response)) => {
log::trace!("Sending response {:?}", response);
log::trace!("sending response {:?}", response);

let msg = WebSocketMessage {
r#type: Some(web_socket_message::Type::Response.into()),
Expand Down

0 comments on commit 6fc62c8

Please sign in to comment.