Skip to content

Commit

Permalink
Don't store csprng instances in structs anymore
Browse files Browse the repository at this point in the history
  • Loading branch information
gferon committed Nov 3, 2024
1 parent b59f8ad commit 839a99a
Show file tree
Hide file tree
Showing 7 changed files with 133 additions and 126 deletions.
78 changes: 42 additions & 36 deletions src/account_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,7 @@ use crate::{

type Aes256Ctr128BE = ctr::Ctr128BE<aes::Aes256>;

pub struct AccountManager<R: Rng + CryptoRng> {
csprng: R,
pub struct AccountManager {
service: PushService,
profile_key: Option<ProfileKey>,
}
Expand All @@ -74,14 +73,9 @@ pub struct Profile {
pub avatar: Option<String>,
}

impl<R: Rng + CryptoRng> AccountManager<R> {
pub fn new(
csprng: R,
service: PushService,
profile_key: Option<ProfileKey>,
) -> Self {
impl AccountManager {
pub fn new(service: PushService, profile_key: Option<ProfileKey>) -> Self {
Self {
csprng,
service,
profile_key,
}
Expand All @@ -94,9 +88,10 @@ impl<R: Rng + CryptoRng> AccountManager<R> {
///
/// Equivalent to Java's RefreshPreKeysJob
#[allow(clippy::too_many_arguments)]
#[tracing::instrument(skip(self, protocol_store))]
pub async fn update_pre_key_bundle<P: PreKeysStore>(
#[tracing::instrument(skip(self, csprng, protocol_store))]
pub async fn update_pre_key_bundle<R: Rng + CryptoRng, P: PreKeysStore>(
&mut self,
csprng: &mut R,
protocol_store: &mut P,
service_id_type: ServiceIdType,
use_last_resort_key: bool,
Expand Down Expand Up @@ -156,8 +151,8 @@ impl<R: Rng + CryptoRng> AccountManager<R> {
let (pre_keys, signed_pre_key, pq_pre_keys, pq_last_resort_key) =
crate::pre_keys::replenish_pre_keys(
protocol_store,
csprng,
&identity_key_pair,
&mut self.csprng,
use_last_resort_key && !has_last_resort_key,
PRE_KEY_BATCH_SIZE,
PRE_KEY_BATCH_SIZE,
Expand Down Expand Up @@ -283,8 +278,9 @@ impl<R: Rng + CryptoRng> AccountManager<R> {
/// ```java
/// TextSecurePreferences.setIsUnidentifiedDeliveryEnabled(context, false);
/// ```
pub async fn link_device(
pub async fn link_device<R: Rng + CryptoRng>(
&mut self,
csprng: &mut R,
url: url::Url,
aci_identity_store: &dyn IdentityKeyStore,
pni_identity_store: &dyn IdentityKeyStore,
Expand Down Expand Up @@ -346,7 +342,7 @@ impl<R: Rng + CryptoRng> AccountManager<R> {

let cipher = ProvisioningCipher::from_public(pub_key);

let encrypted = cipher.encrypt(&mut self.csprng, msg)?;
let encrypted = cipher.encrypt(csprng, msg)?;
self.send_provisioning_message(ephemeral_id, encrypted)
.await?;
Ok(())
Expand Down Expand Up @@ -382,10 +378,12 @@ impl<R: Rng + CryptoRng> AccountManager<R> {
}

pub async fn register_account<
R: Rng + CryptoRng,
Aci: PreKeysStore + IdentityKeyStore,
Pni: PreKeysStore + IdentityKeyStore,
>(
&mut self,
csprng: &mut R,
registration_method: RegistrationMethod<'_>,
account_attributes: AccountAttributes,
aci_protocol_store: &mut Aci,
Expand All @@ -408,8 +406,8 @@ impl<R: Rng + CryptoRng> AccountManager<R> {
aci_last_resort_kyber_prekey,
) = crate::pre_keys::replenish_pre_keys(
aci_protocol_store,
csprng,
&aci_identity_key_pair,
&mut self.csprng,
true,
0,
0,
Expand All @@ -423,8 +421,8 @@ impl<R: Rng + CryptoRng> AccountManager<R> {
pni_last_resort_kyber_prekey,
) = crate::pre_keys::replenish_pre_keys(
pni_protocol_store,
csprng,
&pni_identity_key_pair,
&mut self.csprng,
true,
0,
0,
Expand Down Expand Up @@ -470,15 +468,19 @@ impl<R: Rng + CryptoRng> AccountManager<R> {
/// ```
/// in which the `retain_avatar` parameter sets whether to remove (`false`) or retain (`true`) the
/// currently set avatar.
pub async fn upload_versioned_profile_without_avatar<S: AsRef<str>>(
pub async fn upload_versioned_profile_without_avatar<
R: Rng + CryptoRng,
S: AsRef<str>,
>(
&mut self,
aci: libsignal_protocol::Aci,
name: ProfileName<S>,
about: Option<String>,
about_emoji: Option<String>,
retain_avatar: bool,
csprng: &mut R,
) -> Result<(), ProfileManagerError> {
self.upload_versioned_profile::<std::io::Cursor<Vec<u8>>, _>(
self.upload_versioned_profile::<std::io::Cursor<Vec<u8>>, _, _>(
aci,
name,
about,
Expand All @@ -488,6 +490,7 @@ impl<R: Rng + CryptoRng> AccountManager<R> {
} else {
AvatarWrite::NoAvatar
},
csprng,
)
.await?;
Ok(())
Expand All @@ -505,8 +508,8 @@ impl<R: Rng + CryptoRng> AccountManager<R> {
.retrieve_profile_by_id(address, Some(profile_key))
.await?;

let profile_cipher = ProfileCipher::new(&mut self.csprng, profile_key);
Ok(encrypted_profile.decrypt(profile_cipher)?)
let profile_cipher = ProfileCipher::new(profile_key);
Ok(profile_cipher.decrypt(encrypted_profile)?)
}

/// Upload a profile
Expand All @@ -517,6 +520,7 @@ impl<R: Rng + CryptoRng> AccountManager<R> {
pub async fn upload_versioned_profile<
's,
C: std::io::Read + Send + 's,
R: Rng + CryptoRng,
S: AsRef<str>,
>(
&mut self,
Expand All @@ -525,18 +529,18 @@ impl<R: Rng + CryptoRng> AccountManager<R> {
about: Option<String>,
about_emoji: Option<String>,
avatar: AvatarWrite<&'s mut C>,
csprng: &mut R,
) -> Result<Option<String>, ProfileManagerError> {
let profile_key =
self.profile_key.expect("set profile key in AccountManager");
let mut profile_cipher =
ProfileCipher::new(&mut self.csprng, profile_key);
let profile_cipher = ProfileCipher::new(profile_key);

// Profile encryption
let name = profile_cipher.encrypt_name(name.as_ref())?;
let name = profile_cipher.encrypt_name(name.as_ref(), csprng)?;
let about = about.unwrap_or_default();
let about = profile_cipher.encrypt_about(about)?;
let about = profile_cipher.encrypt_about(about, csprng)?;
let about_emoji = about_emoji.unwrap_or_default();
let about_emoji = profile_cipher.encrypt_emoji(about_emoji)?;
let about_emoji = profile_cipher.encrypt_emoji(about_emoji, csprng)?;

// If avatar -> upload
if matches!(avatar, AvatarWrite::NewAvatar(_)) {
Expand Down Expand Up @@ -573,13 +577,14 @@ impl<R: Rng + CryptoRng> AccountManager<R> {
}

/// Update (encrypted) device name
pub async fn update_device_name(
pub async fn update_device_name<R: Rng + CryptoRng>(
&mut self,
device_name: &str,
public_key: &IdentityKey,
csprng: &mut R,
) -> Result<(), ServiceError> {
let encrypted_device_name =
encrypt_device_name(&mut self.csprng, device_name, public_key)?;
encrypt_device_name(csprng, device_name, public_key)?;

#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
Expand Down Expand Up @@ -640,9 +645,9 @@ impl<R: Rng + CryptoRng> AccountManager<R> {
/// Should be called as the primary device to migrate from pre-PNI to PNI.
///
/// This is the equivalent of Android's PnpInitializeDevicesJob or iOS' PniHelloWorldManager.
#[tracing::instrument(skip(self, aci_protocol_store, pni_protocol_store, sender, local_aci), fields(local_aci = %local_aci))]
#[tracing::instrument(skip(self, aci_protocol_store, pni_protocol_store, sender, local_aci, csprng), fields(local_aci = %local_aci))]
pub async fn pnp_initialize_devices<
// XXX So many constraints here, all imposed by the MessageSender
R: Rng + CryptoRng,
Aci: PreKeysStore + SessionStoreExt,
Pni: PreKeysStore,
AciOrPni: ProtocolStore + SenderKeyStore + SessionStoreExt + Sync + Clone,
Expand All @@ -653,6 +658,7 @@ impl<R: Rng + CryptoRng> AccountManager<R> {
mut sender: MessageSender<AciOrPni, R>,
local_aci: ServiceAddress,
e164: PhoneNumber,
csprng: &mut R,
) -> Result<(), MessageSenderError> {
let pni_identity_key_pair =
pni_protocol_store.get_identity_key_pair().await?;
Expand Down Expand Up @@ -709,25 +715,25 @@ impl<R: Rng + CryptoRng> AccountManager<R> {
) = if local_device_id == DEFAULT_DEVICE_ID {
crate::pre_keys::replenish_pre_keys(
pni_protocol_store,
csprng,
&pni_identity_key_pair,
&mut self.csprng,
true,
0,
0,
)
.await?
} else {
// Generate a signed prekey
let signed_pre_key_pair = KeyPair::generate(&mut self.csprng);
let signed_pre_key_pair = KeyPair::generate(csprng);
let signed_pre_key_public = signed_pre_key_pair.public_key;
let signed_pre_key_signature =
pni_identity_key_pair.private_key().calculate_signature(
&signed_pre_key_public.serialize(),
&mut self.csprng,
csprng,
)?;

let signed_prekey_record = SignedPreKeyRecord::new(
self.csprng.gen_range::<u32, _>(0..0xFFFFFF).into(),
csprng.gen_range::<u32, _>(0..0xFFFFFF).into(),
Timestamp::now(),
&signed_pre_key_pair,
&signed_pre_key_signature,
Expand All @@ -736,7 +742,7 @@ impl<R: Rng + CryptoRng> AccountManager<R> {
// Generate a last-resort Kyber prekey
let kyber_pre_key_record = KyberPreKeyRecord::generate(
kem::KeyType::Kyber1024,
self.csprng.gen_range::<u32, _>(0..0xFFFFFF).into(),
csprng.gen_range::<u32, _>(0..0xFFFFFF).into(),
pni_identity_key_pair.private_key(),
)?;
(
Expand All @@ -751,7 +757,7 @@ impl<R: Rng + CryptoRng> AccountManager<R> {
pni_protocol_store.get_local_registration_id().await?
} else {
loop {
let regid = generate_registration_id(&mut self.csprng);
let regid = generate_registration_id(csprng);
if !pni_registration_ids.iter().any(|(_k, v)| *v == regid) {
break regid;
}
Expand Down Expand Up @@ -799,7 +805,7 @@ impl<R: Rng + CryptoRng> AccountManager<R> {
e164.format().mode(phonenumber::Mode::E164).to_string(),
),
}),
padding: Some(random_length_padding(&mut self.csprng, 512)),
padding: Some(random_length_padding(csprng, 512)),
..SyncMessage::default()
};
let content: ContentBody = msg.into();
Expand Down
34 changes: 16 additions & 18 deletions src/cipher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,19 +29,17 @@ use crate::{
///
/// Equivalent of SignalServiceCipher in Java.
#[derive(Clone)]
pub struct ServiceCipher<S, R> {
pub struct ServiceCipher<S> {
protocol_store: S,
csprng: R,
trust_root: PublicKey,
local_uuid: Uuid,
local_device_id: u32,
}

impl<S, R> fmt::Debug for ServiceCipher<S, R> {
impl<S> fmt::Debug for ServiceCipher<S> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("ServiceCipher")
.field("protocol_store", &"...")
.field("csprng", &"...")
.field("trust_root", &"...")
.field("local_uuid", &self.local_uuid)
.field("local_device_id", &self.local_device_id)
Expand Down Expand Up @@ -70,21 +68,18 @@ fn debug_envelope(envelope: &Envelope) -> String {
}
}

impl<S, R> ServiceCipher<S, R>
impl<S> ServiceCipher<S>
where
S: ProtocolStore + SenderKeyStore + SessionStoreExt + Clone,
R: Rng + CryptoRng,
{
pub fn new(
protocol_store: S,
csprng: R,
trust_root: PublicKey,
local_uuid: Uuid,
local_device_id: u32,
) -> Self {
Self {
protocol_store,
csprng,
trust_root,
local_uuid,
local_device_id,
Expand All @@ -94,13 +89,14 @@ where
/// Opens ("decrypts") an envelope.
///
/// Envelopes may be empty, in which case this method returns `Ok(None)`
#[tracing::instrument(skip(envelope), fields(envelope = debug_envelope(&envelope)))]
pub async fn open_envelope(
#[tracing::instrument(skip(envelope, csprng), fields(envelope = debug_envelope(&envelope)))]
pub async fn open_envelope<R: Rng + CryptoRng>(
&mut self,
envelope: Envelope,
csprng: &mut R,
) -> Result<Option<Content>, ServiceError> {
if envelope.content.is_some() {
let plaintext = self.decrypt(&envelope).await?;
let plaintext = self.decrypt(&envelope, csprng).await?;
let message =
crate::proto::Content::decode(plaintext.data.as_slice())?;
if let Some(bytes) = message.sender_key_distribution_message {
Expand All @@ -126,10 +122,11 @@ where
/// Triage of legacy messages happens inside this method, as opposed to the
/// Java implementation, because it makes the borrow checker and the
/// author happier.
#[tracing::instrument(skip(envelope), fields(envelope = debug_envelope(envelope)))]
async fn decrypt(
#[tracing::instrument(skip(envelope, csprng), fields(envelope = debug_envelope(envelope)))]
async fn decrypt<R: Rng + CryptoRng>(
&mut self,
envelope: &Envelope,
csprng: &mut R,
) -> Result<Plaintext, ServiceError> {
let ciphertext = if let Some(msg) = envelope.content.as_ref() {
msg
Expand Down Expand Up @@ -180,7 +177,7 @@ where
&mut self.protocol_store.clone(),
&self.protocol_store.clone(),
&mut self.protocol_store.clone(),
&mut self.csprng,
csprng,
)
.await?
.as_slice()
Expand Down Expand Up @@ -236,7 +233,7 @@ where
&sender,
&mut self.protocol_store.clone(),
&mut self.protocol_store.clone(),
&mut self.csprng,
csprng,
)
.await?
.as_slice()
Expand Down Expand Up @@ -319,18 +316,19 @@ where
}

#[tracing::instrument(
skip(address, unidentified_access, content),
skip(address, unidentified_access, content, csprng),
fields(
address = %address,
with_unidentified_access = unidentified_access.is_some(),
content_length = content.len(),
)
)]
pub(crate) async fn encrypt(
pub(crate) async fn encrypt<R: Rng + CryptoRng>(
&mut self,
address: &ProtocolAddress,
unidentified_access: Option<&SenderCertificate>,
content: &[u8],
csprng: &mut R,
) -> Result<OutgoingPushMessage, ServiceError> {
let session_record = self
.protocol_store
Expand All @@ -354,7 +352,7 @@ where
&mut self.protocol_store.clone(),
&mut self.protocol_store,
SystemTime::now(),
&mut self.csprng,
csprng,
)
.await?;

Expand Down
Loading

0 comments on commit 839a99a

Please sign in to comment.