Skip to content

Commit

Permalink
Switch to reqwest (#332)
Browse files Browse the repository at this point in the history
Also:

* Switch to reqwest and reqwest-websocket
* Use PushService instead of websocket like in libsignal.kt
* Introduce trait to share error handling between PushService and WebSocketService
  • Loading branch information
gferon authored Oct 18, 2024
1 parent 026d751 commit 9782968
Show file tree
Hide file tree
Showing 24 changed files with 960 additions and 1,350 deletions.
16 changes: 4 additions & 12 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,22 +35,14 @@ url = { version = "2.1", features = ["serde"] }
uuid = { version = "1", features = ["serde"] }

# http
hyper = "1.0"
hyper-util = { version = "0.1", features = ["client", "client-legacy"] }
hyper-rustls = { version = "0.27", default-features = false, features = ["http1", "http2", "ring", "logging"] }
hyper-timeout = "0.5"
headers = "0.4"
http-body-util = "0.1"
mpart-async = "0.7"
async-tungstenite = { version = "0.27", features = ["tokio-rustls-native-certs", "url"] }
tokio = { version = "1.0", features = ["macros"] }
tokio-rustls = { version = "0.26", default-features = false, features = ["logging", "ring"] }

rustls-pemfile = "2.0"
reqwest = { version = "0.12", default-features = false, features = ["json", "multipart", "rustls-tls-manual-roots", "stream"] }
reqwest-websocket = { version = "0.4.2", features = ["json"] }

tracing = { version = "0.1", features = ["log"] }
tracing-futures = "0.2"

tokio = { version = "1.0", features = ["macros"] }

[build-dependencies]
prost-build = "0.13"

Expand Down
83 changes: 50 additions & 33 deletions src/account_manager.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use base64::prelude::*;
use phonenumber::PhoneNumber;
use reqwest::Method;
use std::collections::HashMap;
use std::convert::{TryFrom, TryInto};

Expand Down Expand Up @@ -28,9 +29,9 @@ use crate::proto::sync_message::PniChangeNumber;
use crate::proto::{DeviceName, SyncMessage};
use crate::provisioning::generate_registration_id;
use crate::push_service::{
AvatarWrite, DeviceActivationRequest, DeviceInfo, RecaptchaAttributes,
RegistrationMethod, ServiceIdType, VerifyAccountResponse,
DEFAULT_DEVICE_ID,
AvatarWrite, DeviceActivationRequest, DeviceInfo, HttpAuthOverride,
RecaptchaAttributes, RegistrationMethod, ReqwestExt, ServiceIdType,
VerifyAccountResponse, DEFAULT_DEVICE_ID,
};
use crate::sender::OutgoingPushMessage;
use crate::session_store::SessionStoreExt;
Expand All @@ -44,9 +45,7 @@ use crate::{
profile_name::ProfileName,
proto::{ProvisionEnvelope, ProvisionMessage, ProvisioningVersion},
provisioning::{ProvisioningCipher, ProvisioningError},
push_service::{
AccountAttributes, HttpAuthOverride, PushService, ServiceError,
},
push_service::{AccountAttributes, PushService, ServiceError},
utils::serde_base64,
};

Expand Down Expand Up @@ -224,13 +223,19 @@ impl AccountManager {

let dc: DeviceCode = self
.service
.get_json(
.request(
Method::GET,
Endpoint::Service,
"/v1/devices/provisioning/code",
&[],
HttpAuthOverride::NoOverride,
)
)?
.send()
.await?
.service_error_for_status()
.await?
.json()
.await?;

Ok(dc.verification_code)
}

Expand All @@ -247,16 +252,21 @@ impl AccountManager {
let body = env.encode_to_vec();

self.service
.put_json(
.request(
Method::PUT,
Endpoint::Service,
&format!("/v1/provisioning/{}", destination),
&[],
format!("/v1/provisioning/{}", destination),
HttpAuthOverride::NoOverride,
&ProvisioningMessage {
body: BASE64_RELAXED.encode(body),
},
)
.await
)?
.json(&ProvisioningMessage {
body: BASE64_RELAXED.encode(body),
})
.send()
.await?
.service_error_for_status()
.await?;

Ok(())
}

/// Link a new device, given a tsurl.
Expand Down Expand Up @@ -582,15 +592,18 @@ impl AccountManager {
}

self.service
.put_json::<(), _>(
.request(
Method::PUT,
Endpoint::Service,
"/v1/accounts/name",
&[],
HttpAuthOverride::NoOverride,
Data {
device_name: encrypted_device_name.encode_to_vec(),
},
)
)?
.json(&Data {
device_name: encrypted_device_name.encode_to_vec(),
})
.send()
.await?
.service_error_for_status()
.await?;

Ok(())
Expand All @@ -607,20 +620,24 @@ impl AccountManager {
token: &str,
captcha: &str,
) -> Result<(), ServiceError> {
let payload = RecaptchaAttributes {
r#type: String::from("recaptcha"),
token: String::from(token),
captcha: String::from(captcha),
};
self.service
.put_json(
.request(
Method::PUT,
Endpoint::Service,
"/v1/challenge",
&[],
HttpAuthOverride::NoOverride,
payload,
)
.await
)?
.json(&RecaptchaAttributes {
r#type: String::from("recaptcha"),
token: String::from(token),
captcha: String::from(captcha),
})
.send()
.await?
.service_error_for_status()
.await?;

Ok(())
}

/// Initialize PNI on linked devices.
Expand Down
26 changes: 8 additions & 18 deletions src/cipher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,10 +134,9 @@ where
let ciphertext = if let Some(msg) = envelope.content.as_ref() {
msg
} else {
return Err(ServiceError::InvalidFrameError {
return Err(ServiceError::InvalidFrame {
reason:
"Envelope should have either a legacy message or content."
.into(),
"envelope should have either a legacy message or content.",
});
};

Expand Down Expand Up @@ -311,11 +310,8 @@ where
},
_ => {
// else
return Err(ServiceError::InvalidFrameError {
reason: format!(
"Envelope has unknown type {:?}.",
envelope.r#type()
),
return Err(ServiceError::InvalidFrame {
reason: "envelope has unknown type",
});
},
};
Expand Down Expand Up @@ -408,9 +404,7 @@ struct Plaintext {
#[allow(clippy::comparison_chain)]
fn add_padding(version: u32, contents: &[u8]) -> Result<Vec<u8>, ServiceError> {
if version < 2 {
Err(ServiceError::InvalidFrameError {
reason: format!("Unknown version {}", version),
})
Err(ServiceError::PaddingVersion(version))
} else if version == 2 {
Ok(contents.to_vec())
} else {
Expand All @@ -436,8 +430,8 @@ fn strip_padding_version(
contents: &mut Vec<u8>,
) -> Result<(), ServiceError> {
if version < 2 {
Err(ServiceError::InvalidFrameError {
reason: format!("Unknown version {}", version),
Err(ServiceError::InvalidFrame {
reason: "unknown version",
})
} else if version == 2 {
Ok(())
Expand All @@ -449,11 +443,7 @@ fn strip_padding_version(

#[allow(clippy::comparison_chain)]
fn strip_padding(contents: &mut Vec<u8>) -> Result<(), ServiceError> {
let new_length = Iso7816::raw_unpad(contents)
.map_err(|e| ServiceError::InvalidFrameError {
reason: format!("Invalid message padding: {:?}", e),
})?
.len();
let new_length = Iso7816::raw_unpad(contents)?.len();
contents.resize(new_length, 0);
Ok(())
}
Expand Down
4 changes: 2 additions & 2 deletions src/envelope.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ impl Envelope {
if input.len() < VERSION_LENGTH
|| input[VERSION_OFFSET] != SUPPORTED_VERSION
{
return Err(ServiceError::InvalidFrameError {
reason: "Unsupported signaling cryptogram version".into(),
return Err(ServiceError::InvalidFrame {
reason: "unsupported signaling cryptogram version",
});
}

Expand Down
32 changes: 17 additions & 15 deletions src/groups_v2/manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@ use std::{collections::HashMap, convert::TryInto};

use crate::{
configuration::Endpoint,
groups_v2::model::{Group, GroupChanges},
groups_v2::operations::{GroupDecodingError, GroupOperations},
groups_v2::{
model::{Group, GroupChanges},
operations::{GroupDecodingError, GroupOperations},
},
prelude::{PushService, ServiceError},
proto::GroupContextV2,
push_service::{HttpAuth, HttpAuthOverride, ServiceIds},
push_service::{HttpAuth, HttpAuthOverride, ReqwestExt, ServiceIds},
utils::BASE64_RELAXED,
};

Expand All @@ -15,6 +17,7 @@ use bytes::Bytes;
use chrono::{Days, NaiveDate, NaiveTime, Utc};
use futures::AsyncReadExt;
use rand::RngCore;
use reqwest::Method;
use serde::Deserialize;
use zkgroup::{
auth::AuthCredentialWithPniResponse,
Expand Down Expand Up @@ -165,20 +168,24 @@ impl<C: CredentialsCache> GroupsManager<C> {

let credentials_response: CredentialResponse = self
.push_service
.get_json(
.request(
Method::GET,
Endpoint::Service,
&path,
&[],
HttpAuthOverride::NoOverride,
)
)?
.send()
.await?
.service_error_for_status()
.await?
.json()
.await?;
self.credentials_cache
.write(credentials_response.parse()?)?;
self.credentials_cache.get(&today)?.ok_or_else(|| {
ServiceError::ResponseError {
ServiceError::InvalidFrame {
reason:
"credentials received did not contain requested day"
.into(),
"credentials received did not contain requested day",
}
})?
};
Expand Down Expand Up @@ -279,12 +286,7 @@ impl<C: CredentialsCache> GroupsManager<C> {
.retrieve_groups_v2_profile_avatar(path)
.await?;
let mut result = Vec::with_capacity(10 * 1024 * 1024);
encrypted_avatar
.read_to_end(&mut result)
.await
.map_err(|e| ServiceError::ResponseError {
reason: format!("reading avatar data: {}", e),
})?;
encrypted_avatar.read_to_end(&mut result).await?;
Ok(GroupOperations::new(group_secret_params).decrypt_avatar(&result))
}

Expand Down
33 changes: 3 additions & 30 deletions src/messagepipe.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
use bytes::Bytes;
use futures::{
channel::{
mpsc::{self, Sender},
oneshot,
},
prelude::*,
stream::FusedStream,
};

pub use crate::{
Expand All @@ -18,24 +16,12 @@ pub use crate::{

use crate::{push_service::ServiceError, websocket::SignalWebSocket};

pub enum WebSocketStreamItem {
Message(Bytes),
KeepAliveRequest,
}

#[derive(Debug)]
pub enum Incoming {
Envelope(Envelope),
QueueEmpty,
}

#[async_trait::async_trait]
pub trait WebSocketService {
type Stream: FusedStream<Item = WebSocketStreamItem> + Unpin;

async fn send_message(&mut self, msg: Bytes) -> Result<(), ServiceError>;
}

pub struct MessagePipe {
ws: SignalWebSocket,
credentials: ServiceCredentials,
Expand Down Expand Up @@ -93,8 +79,8 @@ impl MessagePipe {
let body = if let Some(body) = request.body.as_ref() {
body
} else {
return Err(ServiceError::InvalidFrameError {
reason: "Request without body.".into(),
return Err(ServiceError::InvalidFrame {
reason: "request without body.",
});
};
Some(Incoming::Envelope(Envelope::decrypt(
Expand All @@ -111,7 +97,7 @@ impl MessagePipe {
responder
.send(response)
.map_err(|_| ServiceError::WsClosing {
reason: "could not respond to message pipe request".into(),
reason: "could not respond to message pipe request",
})?;

Ok(result)
Expand All @@ -133,16 +119,3 @@ impl MessagePipe {
combined.filter_map(|x| async { x })
}
}

/// WebSocketService that panics on every request, mainly for example code.
pub struct PanicingWebSocketService;

#[allow(clippy::diverging_sub_expression)]
#[async_trait::async_trait]
impl WebSocketService for PanicingWebSocketService {
type Stream = futures::channel::mpsc::Receiver<WebSocketStreamItem>;

async fn send_message(&mut self, _msg: Bytes) -> Result<(), ServiceError> {
todo!();
}
}
Loading

0 comments on commit 9782968

Please sign in to comment.