Skip to content

Commit

Permalink
Mark internally generated sync messages as such, to be able to proper…
Browse files Browse the repository at this point in the history
…ly set the device IDs
  • Loading branch information
gferon committed Nov 19, 2023
1 parent afb5114 commit 88d9680
Showing 1 changed file with 58 additions and 58 deletions.
116 changes: 58 additions & 58 deletions libsignal-service/src/sender.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
use std::time::SystemTime;
use std::{collections::HashSet, time::SystemTime};

use chrono::prelude::*;
use libsignal_protocol::{
process_prekey_bundle, DeviceId, ProtocolStore, SenderCertificate,
SenderKeyStore, SignalProtocolError,
};
use log::{info, trace};
use log::{debug, info, trace};
use rand::{CryptoRng, Rng};
use uuid::Uuid;

Expand Down Expand Up @@ -77,7 +77,7 @@ pub struct AttachmentSpec {
/// Equivalent of Java's `SignalServiceMessageSender`.
#[derive(Clone)]
pub struct MessageSender<Service, S, R> {
ws: SignalWebSocket,
identified_ws: SignalWebSocket,
unidentified_ws: SignalWebSocket,
service: Service,
cipher: ServiceCipher<S, R>,
Expand Down Expand Up @@ -137,7 +137,7 @@ where
) -> Self {
MessageSender {
service,
ws,
identified_ws: ws,
unidentified_ws,
cipher,
csprng,
Expand Down Expand Up @@ -192,7 +192,10 @@ where

// Request upload attributes
log::trace!("Requesting upload attributes");
let attrs = self.ws.get_attachment_v2_upload_attributes().await?;
let attrs = self
.identified_ws
.get_attachment_v2_upload_attributes()
.await?;

log::trace!("Uploading attachment");
let (id, digest) = self
Expand Down Expand Up @@ -302,6 +305,7 @@ where
_ => false,
};

// try to send the original message to all the recipient's devices
let mut results = vec![
self.try_send_message(
*recipient,
Expand All @@ -313,6 +317,7 @@ where
&content_body,
timestamp,
online,
false,
)
.await,
];
Expand All @@ -337,6 +342,7 @@ where
&sync_message,
timestamp,
false,
true,
)
.await?;
},
Expand Down Expand Up @@ -377,6 +383,7 @@ where
&content_body,
timestamp,
online,
false,
)
.await;

Expand Down Expand Up @@ -413,6 +420,7 @@ where
&sync_message,
timestamp,
false,
true,
)
.await;

Expand All @@ -431,17 +439,21 @@ where
content_body: &ContentBody,
timestamp: u64,
online: bool,
is_sync_message: bool,
) -> SendMessageResult {
use prost::Message;

let content = content_body.clone().into_proto();

let content_bytes = content.encode_to_vec();

for _ in 0..4u8 {
let messages = self
.create_encrypted_messages(
&recipient,
unidentified_access.as_ref().map(|x| &x.certificate),
unidentified_access.map(|x| &x.certificate),
&content_bytes,
is_sync_message,
)
.await?;

Expand All @@ -453,13 +465,13 @@ where
};

let send = if let Some(unidentified) = &unidentified_access {
log::trace!("Sending via unidentified");
log::debug!("sending via unidentified");
self.unidentified_ws
.send_messages_unidentified(messages, unidentified)
.await
} else {
log::trace!("Sending identified");
self.ws.send_messages(messages).await
log::debug!("sending identified");
self.identified_ws.send_messages(messages).await
};

match send {
Expand Down Expand Up @@ -605,53 +617,35 @@ where
recipient: &ServiceAddress,
unidentified_access: Option<&SenderCertificate>,
content: &[u8],
is_sync_message: bool,
) -> Result<Vec<OutgoingPushMessage>, MessageSenderError> {
let mut messages = vec![];
let mut devices = vec![];

let myself = *recipient == self.local_address;
if !myself || unidentified_access.is_some() {
trace!("sending message to default device");
messages.push(
self.create_encrypted_message(
recipient,
unidentified_access,
DEFAULT_DEVICE_ID.into(),
content,
)
.await?,
);
} else {
devices.push(DEFAULT_DEVICE_ID);
}
let mut devices: HashSet<DeviceId> = self
.protocol_store
.get_sub_device_sessions(recipient)
.await?
.into_iter()
.map(DeviceId::from)
.collect();

devices.extend(
self.protocol_store
.get_sub_device_sessions(recipient)
.await?,
);
devices.sort_unstable();

let original_device_count = devices.len();
devices.dedup();
if devices.len() != original_device_count {
log::warn!("SessionStoreExt::get_sub_device_sessions should return unique device id's, and should not include DEFAULT_DEVICE_ID.");
}
// always send to the primary device no matter what
devices.insert(DEFAULT_DEVICE_ID.into());

// When sending to ourselves, don't include the local device
if myself {
devices.retain(|id| DeviceId::from(*id) != self.device_id);
// when sending a sync message, we always send it to all devices except the current one
if is_sync_message {
devices.remove(&self.device_id);
}

for device_id in devices {
trace!("sending message to device {}", device_id);
let ppa = get_preferred_protocol_address(
&self.protocol_store,
recipient,
device_id.into(),
)
.await?;
if self.protocol_store.load_session(&ppa).await?.is_some() {
debug!("sending message to device {}", device_id);
if device_id == DEFAULT_DEVICE_ID.into()
|| self
.protocol_store
.load_session(&recipient.to_protocol_address(device_id))
.await?
.is_some()
{
messages.push(
self.create_encrypted_message(
recipient,
Expand All @@ -677,21 +671,25 @@ where
device_id: DeviceId,
content: &[u8],
) -> Result<OutgoingPushMessage, MessageSenderError> {
let recipient_address = get_preferred_protocol_address(
&self.protocol_store,
recipient,
device_id,
)
.await?;
log::trace!("encrypting message for {:?}", recipient_address);
let recipient_protocol_address =
recipient.to_protocol_address(device_id);

// let is_self =
// recipient == &self.local_address && device_id == self.device_id;
log::debug!("encrypting message for {}", recipient_protocol_address);

// establish a session with the recipient/device if necessary
// no need to establish a session with ourselves (and our own current device)
if self
.protocol_store
.load_session(&recipient_address)
.load_session(&recipient_protocol_address)
.await?
.is_none()
{
info!("establishing new session with {:?}", recipient_address);
info!(
"establishing new session with {}",
recipient_protocol_address
);
let pre_keys = match self
.service
.get_pre_keys(recipient, device_id.into())
Expand All @@ -708,6 +706,7 @@ where
},
Err(e) => Err(e)?,
};

for pre_key_bundle in pre_keys {
if recipient == &self.local_address
&& self.device_id == pre_key_bundle.device_id()?
Expand Down Expand Up @@ -737,8 +736,9 @@ where

let message = self
.cipher
.encrypt(&recipient_address, unidentified_access, content)
.encrypt(&recipient_protocol_address, unidentified_access, content)
.await?;

Ok(message)
}

Expand Down

0 comments on commit 88d9680

Please sign in to comment.