Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix sending SyncMessage to self #263

Merged
merged 5 commits into from
Nov 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions libsignal-service/src/proto.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
#![allow(clippy::all)]

use rand::{Rng, RngCore};
include!(concat!(env!("OUT_DIR"), "/signalservice.rs"));
include!(concat!(env!("OUT_DIR"), "/signal.rs"));

Expand Down Expand Up @@ -65,3 +67,17 @@ impl WebSocketResponseMessage {
}
}
}

impl SyncMessage {
pub fn with_padding() -> Self {
Comment on lines +71 to +72
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@gferon Shouldn't we make those stronger types, such that it's impossible to transmit without padding?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, that was a ninja trick for the time being.

let mut rng = rand::thread_rng();
let random_size = rng.gen_range(1..=512);
let mut padding: Vec<u8> = vec![0; random_size];
rng.fill_bytes(&mut padding);

Self {
padding: Some(padding),
..Default::default()
}
}
}
136 changes: 62 additions & 74 deletions libsignal-service/src/sender.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
use std::time::SystemTime;
use std::{collections::HashSet, time::SystemTime};

use chrono::prelude::*;
use libsignal_protocol::{
process_prekey_bundle, DeviceId, ProtocolStore, SenderCertificate,
SenderKeyStore, SignalProtocolError,
};
use log::{info, trace};
use log::{debug, info, trace};
use rand::{CryptoRng, Rng};
use uuid::Uuid;

Expand Down Expand Up @@ -77,7 +77,7 @@ pub struct AttachmentSpec {
/// Equivalent of Java's `SignalServiceMessageSender`.
#[derive(Clone)]
pub struct MessageSender<Service, S, R> {
ws: SignalWebSocket,
identified_ws: SignalWebSocket,
unidentified_ws: SignalWebSocket,
service: Service,
cipher: ServiceCipher<S, R>,
Expand Down Expand Up @@ -126,7 +126,7 @@ where
{
#[allow(clippy::too_many_arguments)]
pub fn new(
ws: SignalWebSocket,
identified_ws: SignalWebSocket,
unidentified_ws: SignalWebSocket,
service: Service,
cipher: ServiceCipher<S, R>,
Expand All @@ -137,7 +137,7 @@ where
) -> Self {
MessageSender {
service,
ws,
identified_ws,
unidentified_ws,
cipher,
csprng,
Expand Down Expand Up @@ -192,7 +192,10 @@ where

// Request upload attributes
log::trace!("Requesting upload attributes");
let attrs = self.ws.get_attachment_v2_upload_attributes().await?;
let attrs = self
.identified_ws
.get_attachment_v2_upload_attributes()
.await?;

log::trace!("Uploading attachment");
let (id, digest) = self
Expand Down Expand Up @@ -287,29 +290,35 @@ where
pub async fn send_message(
&mut self,
recipient: &ServiceAddress,
unidentified_access: Option<UnidentifiedAccess>,
mut unidentified_access: Option<UnidentifiedAccess>,
message: impl Into<ContentBody>,
timestamp: u64,
online: bool,
) -> SendMessageResult {
let content_body = message.into();

use crate::proto::data_message::Flags;

let end_session = match &content_body {
ContentBody::DataMessage(message) => {
unidentified_access.take(); // don't send end session as sealed sender
message.flags == Some(Flags::EndSession as u32)
},
_ => false,
};

// we never send sync messages or to our own account as sealed sender
if recipient == &self.local_address
|| matches!(&content_body, ContentBody::SynchronizeMessage(_))
{
unidentified_access.take();
}

// try to send the original message to all the recipient's devices
let mut results = vec![
self.try_send_message(
*recipient,
if !end_session {
unidentified_access.as_ref()
} else {
None
},
unidentified_access.as_ref(),
&content_body,
timestamp,
online,
Expand Down Expand Up @@ -433,14 +442,16 @@ where
online: bool,
) -> SendMessageResult {
use prost::Message;

let content = content_body.clone().into_proto();

let content_bytes = content.encode_to_vec();

for _ in 0..4u8 {
let messages = self
.create_encrypted_messages(
&recipient,
unidentified_access.as_ref().map(|x| &x.certificate),
unidentified_access.map(|x| &x.certificate),
&content_bytes,
)
.await?;
Expand All @@ -453,13 +464,13 @@ where
};

let send = if let Some(unidentified) = &unidentified_access {
log::trace!("Sending via unidentified");
log::debug!("sending via unidentified");
self.unidentified_ws
.send_messages_unidentified(messages, unidentified)
.await
} else {
log::trace!("Sending identified");
self.ws.send_messages(messages).await
log::debug!("sending identified");
self.identified_ws.send_messages(messages).await
};

match send {
Expand Down Expand Up @@ -584,7 +595,7 @@ where
blob: Some(ptr),
complete: Some(complete),
}),
..Default::default()
..SyncMessage::with_padding()
};

self.send_message(
Expand All @@ -607,61 +618,34 @@ where
content: &[u8],
) -> Result<Vec<OutgoingPushMessage>, MessageSenderError> {
let mut messages = vec![];
let mut devices = vec![];

let myself = *recipient == self.local_address;
if !myself || unidentified_access.is_some() {
trace!("sending message to default device");
let mut devices: HashSet<DeviceId> = self
.protocol_store
.get_sub_device_sessions(recipient)
.await?
.into_iter()
.map(DeviceId::from)
.collect();

// always send to the primary device no matter what
devices.insert(DEFAULT_DEVICE_ID.into());

// when sending an identified message, remove ourselves from the list of recipients
if unidentified_access.is_none() {
devices.remove(&self.device_id);
}

for device_id in devices {
debug!("sending message to device {}", device_id);
messages.push(
self.create_encrypted_message(
recipient,
unidentified_access,
DEFAULT_DEVICE_ID.into(),
device_id,
content,
)
.await?,
);
} else {
devices.push(DEFAULT_DEVICE_ID);
}

devices.extend(
self.protocol_store
.get_sub_device_sessions(recipient)
.await?,
);
devices.sort_unstable();

let original_device_count = devices.len();
devices.dedup();
if devices.len() != original_device_count {
log::warn!("SessionStoreExt::get_sub_device_sessions should return unique device id's, and should not include DEFAULT_DEVICE_ID.");
}

// When sending to ourselves, don't include the local device
if myself {
devices.retain(|id| DeviceId::from(*id) != self.device_id);
}

for device_id in devices {
trace!("sending message to device {}", device_id);
let ppa = get_preferred_protocol_address(
&self.protocol_store,
recipient,
device_id.into(),
)
.await?;
if self.protocol_store.load_session(&ppa).await?.is_some() {
messages.push(
self.create_encrypted_message(
recipient,
unidentified_access,
device_id.into(),
content,
)
.await?,
)
}
}

Ok(messages)
Expand All @@ -677,21 +661,23 @@ where
device_id: DeviceId,
content: &[u8],
) -> Result<OutgoingPushMessage, MessageSenderError> {
let recipient_address = get_preferred_protocol_address(
&self.protocol_store,
recipient,
device_id,
)
.await?;
log::trace!("encrypting message for {:?}", recipient_address);
let recipient_protocol_address =
recipient.to_protocol_address(device_id);

log::debug!("encrypting message for {}", recipient_protocol_address);

// establish a session with the recipient/device if necessary
// no need to establish a session with ourselves (and our own current device)
if self
.protocol_store
.load_session(&recipient_address)
.load_session(&recipient_protocol_address)
.await?
.is_none()
{
info!("establishing new session with {:?}", recipient_address);
info!(
"establishing new session with {}",
recipient_protocol_address
);
let pre_keys = match self
.service
.get_pre_keys(recipient, device_id.into())
Expand All @@ -708,6 +694,7 @@ where
},
Err(e) => Err(e)?,
};

for pre_key_bundle in pre_keys {
if recipient == &self.local_address
&& self.device_id == pre_key_bundle.device_id()?
Expand Down Expand Up @@ -737,8 +724,9 @@ where

let message = self
.cipher
.encrypt(&recipient_address, unidentified_access, content)
.encrypt(&recipient_protocol_address, unidentified_access, content)
.await?;

Ok(message)
}

Expand Down Expand Up @@ -786,7 +774,7 @@ where
unidentified_status,
..Default::default()
}),
..Default::default()
..SyncMessage::with_padding()
})
}
}
8 changes: 4 additions & 4 deletions libsignal-service/src/websocket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ impl<WS: WebSocketService> SignalWebSocketProcess<WS> {
frame: Bytes,
) -> Result<(), ServiceError> {
let msg = WebSocketMessage::decode(frame)?;
log::trace!("Decoded {:?}", msg);
log::trace!("decoded {:?}", msg);

use web_socket_message::Type;
match (msg.r#type(), msg.request, msg.response) {
Expand Down Expand Up @@ -184,7 +184,7 @@ impl<WS: WebSocketService> SignalWebSocketProcess<WS> {
.filter(|x| !self.outgoing_request_map.contains_key(x))
.unwrap_or_else(|| self.next_request_id()),
);
log::trace!("Sending request {:?}", request);
log::trace!("sending request {:?}", request);

self.outgoing_request_map.insert(request.id.unwrap(), responder);
let msg = WebSocketMessage {
Expand Down Expand Up @@ -227,7 +227,7 @@ impl<WS: WebSocketService> SignalWebSocketProcess<WS> {
self.ws.send_message(buffer.into()).await?
}
Some(WebSocketStreamItem::KeepAliveRequest) => {
log::debug!("keep alive is disabled: ignoring request");
log::trace!("keep alive is disabled: ignoring request");
}
None => {
return Err(ServiceError::WsError {
Expand All @@ -239,7 +239,7 @@ impl<WS: WebSocketService> SignalWebSocketProcess<WS> {
response = self.outgoing_responses.next() => {
match response {
Some(Ok(response)) => {
log::trace!("Sending response {:?}", response);
log::trace!("sending response {:?}", response);

let msg = WebSocketMessage {
r#type: Some(web_socket_message::Type::Response.into()),
Expand Down
Loading