From c50bb9896f68b50d98b491fc5bdf286790b1eba1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20F=C3=A9ron?= Date: Wed, 16 Oct 2024 13:50:50 +0200 Subject: [PATCH] Switch to reqwest and reqwest-websocket --- Cargo.toml | 16 +- src/account_manager.rs | 73 +++-- src/groups_v2/manager.rs | 18 +- src/messagepipe.rs | 27 -- src/push_service/account.rs | 47 +-- src/push_service/cdn.rs | 209 ++++-------- src/push_service/error.rs | 9 +- src/push_service/keys.rs | 113 ++++--- src/push_service/linking.rs | 27 +- src/push_service/mod.rs | 544 ++++++++----------------------- src/push_service/profile.rs | 44 +-- src/push_service/registration.rs | 166 +++++----- src/sender.rs | 3 +- src/websocket/mod.rs | 125 ++++--- src/websocket/tungstenite.rs | 15 +- 15 files changed, 579 insertions(+), 857 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index c3213bb4e..cc43f204a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -35,22 +35,14 @@ url = { version = "2.1", features = ["serde"] } uuid = { version = "1", features = ["serde"] } # http -hyper = "1.0" -hyper-util = { version = "0.1", features = ["client", "client-legacy"] } -hyper-rustls = { version = "0.27", default-features = false, features = ["http1", "http2", "ring", "logging"] } -hyper-timeout = "0.5" -headers = "0.4" -http-body-util = "0.1" -mpart-async = "0.7" -async-tungstenite = { version = "0.27", features = ["tokio-rustls-native-certs", "url"] } -tokio = { version = "1.0", features = ["macros"] } -tokio-rustls = { version = "0.26", default-features = false, features = ["logging", "ring"] } - -rustls-pemfile = "2.0" +reqwest = { version = "0.12", default-features = false, features = ["json", "multipart", "rustls-tls-manual-roots", "stream"] } +reqwest-websocket = { version = "0.4.2", features = ["json"] } tracing = { version = "0.1", features = ["log"] } tracing-futures = "0.2" +tokio = { version = "1.0", features = ["macros"] } + [build-dependencies] prost-build = "0.13" diff --git a/src/account_manager.rs b/src/account_manager.rs index d8251c92b..84ade2573 100644 --- a/src/account_manager.rs +++ b/src/account_manager.rs @@ -1,5 +1,6 @@ use base64::prelude::*; use phonenumber::PhoneNumber; +use reqwest::Method; use std::collections::HashMap; use std::convert::{TryFrom, TryInto}; @@ -28,9 +29,9 @@ use crate::proto::sync_message::PniChangeNumber; use crate::proto::{DeviceName, SyncMessage}; use crate::provisioning::generate_registration_id; use crate::push_service::{ - AvatarWrite, DeviceActivationRequest, DeviceInfo, RecaptchaAttributes, - RegistrationMethod, ServiceIdType, VerifyAccountResponse, - DEFAULT_DEVICE_ID, + AvatarWrite, DeviceActivationRequest, DeviceInfo, HttpAuthOverride, + RecaptchaAttributes, RegistrationMethod, ReqwestExt, ServiceIdType, + VerifyAccountResponse, DEFAULT_DEVICE_ID, }; use crate::sender::OutgoingPushMessage; use crate::session_store::SessionStoreExt; @@ -44,9 +45,7 @@ use crate::{ profile_name::ProfileName, proto::{ProvisionEnvelope, ProvisionMessage, ProvisioningVersion}, provisioning::{ProvisioningCipher, ProvisioningError}, - push_service::{ - AccountAttributes, HttpAuthOverride, PushService, ServiceError, - }, + push_service::{AccountAttributes, PushService, ServiceError}, utils::serde_base64, }; @@ -224,13 +223,17 @@ impl AccountManager { let dc: DeviceCode = self .service - .get_json( + .request( + Method::GET, Endpoint::Service, "/v1/devices/provisioning/code", - &[], HttpAuthOverride::NoOverride, - ) + )? + .send_to_signal() + .await? + .json() .await?; + Ok(dc.verification_code) } @@ -247,16 +250,19 @@ impl AccountManager { let body = env.encode_to_vec(); self.service - .put_json( + .request( + Method::PUT, Endpoint::Service, &format!("/v1/provisioning/{}", destination), - &[], HttpAuthOverride::NoOverride, - &ProvisioningMessage { - body: BASE64_RELAXED.encode(body), - }, - ) - .await + )? + .json(&ProvisioningMessage { + body: BASE64_RELAXED.encode(body), + }) + .send_to_signal() + .await?; + + Ok(()) } /// Link a new device, given a tsurl. @@ -582,15 +588,16 @@ impl AccountManager { } self.service - .put_json::<(), _>( + .request( + Method::PUT, Endpoint::Service, "/v1/accounts/name", - &[], HttpAuthOverride::NoOverride, - Data { - device_name: encrypted_device_name.encode_to_vec(), - }, - ) + )? + .json(&Data { + device_name: encrypted_device_name.encode_to_vec(), + }) + .send_to_signal() .await?; Ok(()) @@ -607,20 +614,22 @@ impl AccountManager { token: &str, captcha: &str, ) -> Result<(), ServiceError> { - let payload = RecaptchaAttributes { - r#type: String::from("recaptcha"), - token: String::from(token), - captcha: String::from(captcha), - }; self.service - .put_json( + .request( + Method::PUT, Endpoint::Service, "/v1/challenge", - &[], HttpAuthOverride::NoOverride, - payload, - ) - .await + )? + .json(&RecaptchaAttributes { + r#type: String::from("recaptcha"), + token: String::from(token), + captcha: String::from(captcha), + }) + .send_to_signal() + .await?; + + Ok(()) } /// Initialize PNI on linked devices. diff --git a/src/groups_v2/manager.rs b/src/groups_v2/manager.rs index d5e77dffd..9183bef10 100644 --- a/src/groups_v2/manager.rs +++ b/src/groups_v2/manager.rs @@ -2,11 +2,13 @@ use std::{collections::HashMap, convert::TryInto}; use crate::{ configuration::Endpoint, - groups_v2::model::{Group, GroupChanges}, - groups_v2::operations::{GroupDecodingError, GroupOperations}, + groups_v2::{ + model::{Group, GroupChanges}, + operations::{GroupDecodingError, GroupOperations}, + }, prelude::{PushService, ServiceError}, proto::GroupContextV2, - push_service::{HttpAuth, HttpAuthOverride, ServiceIds}, + push_service::{HttpAuth, HttpAuthOverride, ReqwestExt, ServiceIds}, utils::BASE64_RELAXED, }; @@ -15,6 +17,7 @@ use bytes::Bytes; use chrono::{Days, NaiveDate, NaiveTime, Utc}; use futures::AsyncReadExt; use rand::RngCore; +use reqwest::Method; use serde::Deserialize; use zkgroup::{ auth::AuthCredentialWithPniResponse, @@ -165,12 +168,15 @@ impl GroupsManager { let credentials_response: CredentialResponse = self .push_service - .get_json( + .request( + Method::GET, Endpoint::Service, &path, - &[], HttpAuthOverride::NoOverride, - ) + )? + .send_to_signal() + .await? + .json() .await?; self.credentials_cache .write(credentials_response.parse()?)?; diff --git a/src/messagepipe.rs b/src/messagepipe.rs index cdb77b7a8..f3e0c67c5 100644 --- a/src/messagepipe.rs +++ b/src/messagepipe.rs @@ -1,11 +1,9 @@ -use bytes::Bytes; use futures::{ channel::{ mpsc::{self, Sender}, oneshot, }, prelude::*, - stream::FusedStream, }; pub use crate::{ @@ -18,24 +16,12 @@ pub use crate::{ use crate::{push_service::ServiceError, websocket::SignalWebSocket}; -pub enum WebSocketStreamItem { - Message(Bytes), - KeepAliveRequest, -} - #[derive(Debug)] pub enum Incoming { Envelope(Envelope), QueueEmpty, } -#[async_trait::async_trait] -pub trait WebSocketService { - type Stream: FusedStream + Unpin; - - async fn send_message(&mut self, msg: Bytes) -> Result<(), ServiceError>; -} - pub struct MessagePipe { ws: SignalWebSocket, credentials: ServiceCredentials, @@ -133,16 +119,3 @@ impl MessagePipe { combined.filter_map(|x| async { x }) } } - -/// WebSocketService that panics on every request, mainly for example code. -pub struct PanicingWebSocketService; - -#[allow(clippy::diverging_sub_expression)] -#[async_trait::async_trait] -impl WebSocketService for PanicingWebSocketService { - type Stream = futures::channel::mpsc::Receiver; - - async fn send_message(&mut self, _msg: Bytes) -> Result<(), ServiceError> { - todo!(); - } -} diff --git a/src/push_service/account.rs b/src/push_service/account.rs index d36eff544..293e0857e 100644 --- a/src/push_service/account.rs +++ b/src/push_service/account.rs @@ -2,10 +2,11 @@ use std::fmt; use chrono::{DateTime, Utc}; use phonenumber::PhoneNumber; +use reqwest::Method; use serde::{Deserialize, Serialize}; use uuid::Uuid; -use super::{HttpAuthOverride, PushService, ServiceError}; +use super::{HttpAuthOverride, PushService, ReqwestExt, ServiceError}; use crate::{ configuration::Endpoint, utils::{serde_optional_base64, serde_phone_number}, @@ -127,13 +128,17 @@ pub struct WhoAmIResponse { impl PushService { /// Method used to check our own UUID pub async fn whoami(&mut self) -> Result { - self.get_json( + self.request( + Method::GET, Endpoint::Service, "/v1/accounts/whoami", - &[], HttpAuthOverride::NoOverride, - ) + )? + .send_to_signal() + .await? + .json() .await + .map_err(Into::into) } /// Fetches a list of all devices tied to the authenticated account. @@ -146,12 +151,15 @@ impl PushService { } let devices: DeviceInfoList = self - .get_json( + .request( + Method::GET, Endpoint::Service, "/v1/devices/", - &[], HttpAuthOverride::NoOverride, - ) + )? + .send_to_signal() + .await? + .json() .await?; Ok(devices.devices) @@ -166,18 +174,17 @@ impl PushService { "only one of PIN and registration lock can be set." ); - match self - .put_json( - Endpoint::Service, - "/v1/accounts/attributes/", - &[], - HttpAuthOverride::NoOverride, - attributes, - ) - .await - { - Err(ServiceError::JsonDecodeError { .. }) => Ok(()), - r => r, - } + self.request( + Method::PUT, + Endpoint::Service, + "/v1/accounts/attributes/", + HttpAuthOverride::NoOverride, + )? + .json(&attributes) + .send_to_signal() + .await? + .json() + .await + .map_err(Into::into) } } diff --git a/src/push_service/cdn.rs b/src/push_service/cdn.rs index 3ba962a6d..8309807b2 100644 --- a/src/push_service/cdn.rs +++ b/src/push_service/cdn.rs @@ -1,19 +1,15 @@ use std::io::{self, Read}; -use bytes::Bytes; -use futures::{FutureExt, StreamExt, TryStreamExt}; -use http_body_util::BodyExt; -use hyper::Method; +use futures::TryStreamExt; +use reqwest::{multipart::Part, Method}; use tracing::debug; use crate::{ - configuration::Endpoint, - prelude::AttachmentIdentifier, - proto::AttachmentPointer, - push_service::{HttpAuthOverride, RequestBody}, + configuration::Endpoint, prelude::AttachmentIdentifier, + proto::AttachmentPointer, push_service::HttpAuthOverride, }; -use super::{PushService, ServiceError}; +use super::{PushService, ReqwestExt, ServiceError}; #[derive(Debug, serde::Deserialize, Default)] #[serde(rename_all = "camelCase")] @@ -32,170 +28,109 @@ pub struct AttachmentV2UploadAttributes { } impl PushService { + pub async fn get_attachment( + &mut self, + ptr: &AttachmentPointer, + ) -> Result { + let id = match ptr.attachment_identifier.as_ref().unwrap() { + AttachmentIdentifier::CdnId(id) => &id.to_string(), + AttachmentIdentifier::CdnKey(key) => key, + }; + self.get_from_cdn(ptr.cdn_number(), &format!("attachments/{}", id)) + .await + } + #[tracing::instrument(skip(self))] pub(crate) async fn get_from_cdn( &mut self, cdn_id: u32, path: &str, ) -> Result { - let response = self + let response_stream = self .request( Method::GET, Endpoint::Cdn(cdn_id), path, - &[], HttpAuthOverride::Unidentified, // CDN requests are always without authentication - None, - ) - .await?; - - Ok(Box::new( - response - .into_body() - .into_data_stream() - .map_err(|e| io::Error::new(io::ErrorKind::Other, e)) - .into_async_read(), - )) - } - - pub async fn get_attachment_by_id( - &mut self, - id: &str, - cdn_id: u32, - ) -> Result { - let path = format!("attachments/{}", id); - self.get_from_cdn(cdn_id, &path).await + )? + .send_to_signal() + .await? + .bytes_stream() + .map_err(|e| io::Error::new(io::ErrorKind::Other, e)) + .into_async_read(); + + Ok(response_stream) } - pub async fn get_attachment( - &mut self, - ptr: &AttachmentPointer, - ) -> Result { - match ptr.attachment_identifier.as_ref().unwrap() { - AttachmentIdentifier::CdnId(id) => { - // cdn_number did not exist for this part of the protocol. - // cdn_number(), however, returns 0 when the field does not - // exist. - self.get_attachment_by_id(&format!("{}", id), ptr.cdn_number()) - .await - }, - AttachmentIdentifier::CdnKey(key) => { - self.get_attachment_by_id(key, ptr.cdn_number()).await - }, - } - } - - pub async fn get_attachment_v2_upload_attributes( + pub(crate) async fn get_attachment_v2_upload_attributes( &mut self, ) -> Result { - self.get_json( + self.request( + Method::GET, Endpoint::Service, "/v2/attachments/form/upload", - &[], HttpAuthOverride::NoOverride, - ) + )? + .send_to_signal() + .await? + .json() .await + .map_err(Into::into) } /// Upload attachment to CDN /// /// Returns attachment ID and the attachment digest - pub async fn upload_attachment<'s, C>( + pub async fn upload_attachment( &mut self, - attrs: &AttachmentV2UploadAttributes, - content: &'s mut C, - ) -> Result<(u64, Vec), ServiceError> - where - C: std::io::Read + Send + 's, - { - let values = [ - ("acl", &attrs.acl as &str), - ("key", &attrs.key), - ("policy", &attrs.policy), - ("Content-Type", "application/octet-stream"), - ("x-amz-algorithm", &attrs.algorithm), - ("x-amz-credential", &attrs.credential), - ("x-amz-date", &attrs.date), - ("x-amz-signature", &attrs.signature), - ]; - - let mut digester = crate::digeststream::DigestingReader::new(content); - - self.post_to_cdn0( - "attachments/", - &values, - Some(("file", &mut digester)), - ) - .await?; - Ok((attrs.attachment_id, digester.finalize())) + attrs: AttachmentV2UploadAttributes, + mut reader: impl Read + Send, + ) -> Result<(u64, Vec), ServiceError> { + let attachment_id = attrs.attachment_id; + let mut digester = + crate::digeststream::DigestingReader::new(&mut reader); + + self.post_to_cdn0("attachments/", attrs, "file".into(), &mut digester) + .await?; + + Ok((attachment_id, digester.finalize())) } - #[tracing::instrument(skip(self, value, file), fields(file = file.as_ref().map(|_| "")))] - pub async fn post_to_cdn0<'s, C>( + #[tracing::instrument(skip(self, upload_attributes, reader))] + pub async fn post_to_cdn0( &mut self, path: &str, - value: &[(&str, &str)], - file: Option<(&str, &'s mut C)>, - ) -> Result<(), ServiceError> - where - C: Read + Send + 's, - { - let mut form = mpart_async::client::MultipartRequest::default(); - - // mpart-async has a peculiar ordering of the form items, - // and Amazon S3 expects them in a very specific order (i.e., the file contents should - // go last. - // - // mpart-async uses a VecDeque internally for ordering the fields in the order given. - // - // https://github.com/cetra3/mpart-async/issues/16 - - for &(k, v) in value { - form.add_field(k, v); - } - - if let Some((filename, file)) = file { - // XXX Actix doesn't cope with none-'static lifetimes - // https://docs.rs/actix-web/3.2.0/actix_web/body/enum.Body.html - let mut buf = Vec::new(); - file.read_to_end(&mut buf) - .expect("infallible Read instance"); - form.add_stream( - "file", - filename, - "application/octet-stream", - futures::future::ok::<_, ()>(Bytes::from(buf)).into_stream(), - ); - } - - let content_type = - format!("multipart/form-data; boundary={}", form.get_boundary()); - - // XXX Amazon S3 needs the Content-Length, but we don't know it without depleting the whole - // stream. Sadly, Content-Length != contents.len(), but should include the whole form. - let mut body_contents = vec![]; - while let Some(b) = form.next().await { - // Unwrap, because no error type was used above - body_contents.extend(b.unwrap()); - } - tracing::trace!( - "Sending PUT with Content-Type={} and length {}", - content_type, - body_contents.len() - ); + upload_attributes: AttachmentV2UploadAttributes, + filename: String, + mut reader: impl Read + Send, + ) -> Result<(), ServiceError> { + let mut form = reqwest::multipart::Form::new(); + form = form.text("acl", upload_attributes.acl); + form = form.text("key", upload_attributes.key); + form = form.text("policy", upload_attributes.policy); + form = form.text("x-amz-algorithm", upload_attributes.algorithm); + form = form.text("x-amz-credential", upload_attributes.credential); + form = form.text("x-amz-date", upload_attributes.date); + form = form.text("x-amz-signature", upload_attributes.signature); + + let mut buf = Vec::new(); + reader + .read_to_end(&mut buf) + .expect("infallible Read instance"); + + form = form.text("Content-Type", "application/octet-stream"); + form = form.text("Content-Length", buf.len().to_string()); + form = form.part("file", Part::bytes(buf).file_name(filename)); let response = self .request( Method::POST, Endpoint::Cdn(0), path, - &[], HttpAuthOverride::NoOverride, - Some(RequestBody { - contents: body_contents, - content_type, - }), - ) + )? + .multipart(form) + .send_to_signal() .await?; debug!("HyperPushService::PUT response: {:?}", response); diff --git a/src/push_service/error.rs b/src/push_service/error.rs index d197667d9..2e96f27c7 100644 --- a/src/push_service/error.rs +++ b/src/push_service/error.rs @@ -39,10 +39,10 @@ pub enum ServiceError { #[error("Unexpected response: HTTP {http_code}")] UnhandledResponseCode { http_code: u16 }, - #[error("Websocket error: {reason}")] - WsError { reason: String }, + #[error("Websocket error: {0}")] + WsError(#[from] reqwest_websocket::Error), #[error("Websocket closing: {reason}")] - WsClosing { reason: String }, + WsClosing { reason: &'static str }, #[error("Invalid frame: {reason}")] InvalidFrameError { reason: String }, @@ -85,4 +85,7 @@ pub enum ServiceError { #[error("invalid device name")] InvalidDeviceName, + + #[error("HTTP reqwest error: {0}")] + Http(#[from] reqwest::Error), } diff --git a/src/push_service/keys.rs b/src/push_service/keys.rs index 50894d1f2..93b7d6636 100644 --- a/src/push_service/keys.rs +++ b/src/push_service/keys.rs @@ -1,6 +1,7 @@ use std::collections::HashMap; use libsignal_protocol::{IdentityKey, PreKeyBundle, SenderCertificate}; +use reqwest::Method; use serde::Deserialize; use crate::{ @@ -13,8 +14,8 @@ use crate::{ }; use super::{ - HttpAuthOverride, PushService, SenderCertificateJson, ServiceError, - ServiceIdType, VerifyAccountResponse, + HttpAuthOverride, PushService, ReqwestExt, SenderCertificateJson, + ServiceError, ServiceIdType, VerifyAccountResponse, }; #[derive(Debug, Deserialize, Default)] @@ -29,13 +30,17 @@ impl PushService { &mut self, service_id_type: ServiceIdType, ) -> Result { - self.get_json( + self.request( + Method::GET, Endpoint::Service, &format!("/v2/keys?identity={}", service_id_type), - &[], HttpAuthOverride::NoOverride, - ) + )? + .send_to_signal() + .await? + .json() .await + .map_err(Into::into) } pub async fn register_pre_keys( @@ -43,19 +48,17 @@ impl PushService { service_id_type: ServiceIdType, pre_key_state: PreKeyState, ) -> Result<(), ServiceError> { - match self - .put_json( - Endpoint::Service, - &format!("/v2/keys?identity={}", service_id_type), - &[], - HttpAuthOverride::NoOverride, - pre_key_state, - ) - .await - { - Err(ServiceError::JsonDecodeError { .. }) => Ok(()), - r => r, - } + self.request( + Method::PUT, + Endpoint::Service, + &format!("/v2/keys?identity={}", service_id_type), + HttpAuthOverride::NoOverride, + )? + .json(&pre_key_state) + .send_to_signal() + .await?; + + Ok(()) } pub async fn get_pre_key( @@ -67,13 +70,17 @@ impl PushService { format!("/v2/keys/{}/{}?pq=true", destination.uuid, device_id); let mut pre_key_response: PreKeyResponse = self - .get_json( + .request( + Method::GET, Endpoint::Service, &path, - &[], HttpAuthOverride::NoOverride, - ) + )? + .send_to_signal() + .await? + .json() .await?; + assert!(!pre_key_response.devices.is_empty()); let identity = IdentityKey::decode(&pre_key_response.identity_key)?; @@ -92,12 +99,15 @@ impl PushService { format!("/v2/keys/{}/{}?pq=true", destination.uuid, device_id) }; let pre_key_response: PreKeyResponse = self - .get_json( + .request( + Method::GET, Endpoint::Service, &path, - &[], HttpAuthOverride::NoOverride, - ) + )? + .send_to_signal() + .await? + .json() .await?; let mut pre_keys = vec![]; let identity = IdentityKey::decode(&pre_key_response.identity_key)?; @@ -111,12 +121,15 @@ impl PushService { &mut self, ) -> Result { let cert: SenderCertificateJson = self - .get_json( + .request( + Method::GET, Endpoint::Service, "/v1/certificate/delivery", - &[], HttpAuthOverride::NoOverride, - ) + )? + .send_to_signal() + .await? + .json() .await?; Ok(SenderCertificate::deserialize(&cert.certificate)?) } @@ -125,12 +138,15 @@ impl PushService { &mut self, ) -> Result { let cert: SenderCertificateJson = self - .get_json( + .request( + Method::GET, Endpoint::Service, "/v1/certificate/delivery?includeE164=false", - &[], HttpAuthOverride::NoOverride, - ) + )? + .send_to_signal() + .await? + .json() .await?; Ok(SenderCertificate::deserialize(&cert.certificate)?) } @@ -160,23 +176,24 @@ impl PushService { pni_registration_ids: HashMap, signature_valid_on_each_signed_pre_key: bool, } - - let res: VerifyAccountResponse = self - .put_json( - Endpoint::Service, - "/v2/accounts/phone_number_identity_key_distribution", - &[], - HttpAuthOverride::NoOverride, - PniKeyDistributionRequest { - pni_identity_key: pni_identity_key.serialize().into(), - device_messages, - device_pni_signed_prekeys, - device_pni_last_resort_kyber_prekeys, - pni_registration_ids, - signature_valid_on_each_signed_pre_key, - }, - ) - .await?; - Ok(res) + self.request( + Method::PUT, + Endpoint::Service, + "/v2/accounts/phone_number_identity_key_distribution", + HttpAuthOverride::NoOverride, + )? + .json(&PniKeyDistributionRequest { + pni_identity_key: pni_identity_key.serialize().into(), + device_messages, + device_pni_signed_prekeys, + device_pni_last_resort_kyber_prekeys, + pni_registration_ids, + signature_valid_on_each_signed_pre_key, + }) + .send_to_signal() + .await? + .json() + .await + .map_err(Into::into) } } diff --git a/src/push_service/linking.rs b/src/push_service/linking.rs index 5d9b9ded8..d10f20116 100644 --- a/src/push_service/linking.rs +++ b/src/push_service/linking.rs @@ -1,3 +1,4 @@ +use reqwest::Method; use serde::{Deserialize, Serialize}; use uuid::Uuid; @@ -5,7 +6,7 @@ use crate::configuration::Endpoint; use super::{ DeviceActivationRequest, HttpAuth, HttpAuthOverride, PushService, - ServiceError, + ReqwestExt, ServiceError, }; #[derive(Debug, Serialize)] @@ -59,18 +60,30 @@ impl PushService { link_request: &LinkRequest, http_auth: HttpAuth, ) -> Result { - self.put_json( + self.request( + Method::PUT, Endpoint::Service, "/v1/devices/link", - &[], HttpAuthOverride::Identified(http_auth), - link_request, - ) + )? + .json(&link_request) + .send_to_signal() + .await? + .json() .await + .map_err(Into::into) } pub async fn unlink_device(&mut self, id: i64) -> Result<(), ServiceError> { - self.delete_json(Endpoint::Service, &format!("/v1/devices/{}", id), &[]) - .await + self.request( + Method::DELETE, + Endpoint::Service, + format!("/v1/devices/{}", id), + HttpAuthOverride::NoOverride, + )? + .send_to_signal() + .await?; + + Ok(()) } } diff --git a/src/push_service/mod.rs b/src/push_service/mod.rs index f839f185b..daf683617 100644 --- a/src/push_service/mod.rs +++ b/src/push_service/mod.rs @@ -1,36 +1,23 @@ -use std::{io, time::Duration}; +use std::time::Duration; use crate::{ configuration::{Endpoint, ServiceCredentials}, pre_keys::{KyberPreKeyEntity, PreKeyEntity, SignedPreKeyEntity}, prelude::ServiceConfiguration, utils::serde_base64, - websocket::{tungstenite::TungsteniteWebSocket, SignalWebSocket}, + websocket::SignalWebSocket, }; -use bytes::{Buf, Bytes}; use derivative::Derivative; -use headers::{Authorization, HeaderMapExt}; -use http_body_util::{BodyExt, Full}; -use hyper::{ - body::Incoming, - header::{CONTENT_LENGTH, CONTENT_TYPE, USER_AGENT}, - Method, Request, Response, StatusCode, -}; -use hyper_rustls::HttpsConnector; -use hyper_timeout::TimeoutConnector; -use hyper_util::{ - client::legacy::{connect::HttpConnector, Client}, - rt::TokioExecutor, -}; use libsignal_protocol::{ error::SignalProtocolError, kem::{Key, Public}, IdentityKey, PreKeyBundle, PublicKey, }; -use prost::Message as ProtobufMessage; +use protobuf::ProtobufResponseExt; +use reqwest::{Method, RequestBuilder, Response, StatusCode}; +use reqwest_websocket::RequestBuilderExt; use serde::{Deserialize, Serialize}; -use tokio_rustls::rustls; use tracing::{debug_span, Instrument}; pub const KEEPALIVE_TIMEOUT_SECONDS: Duration = Duration::from_secs(55); @@ -67,16 +54,6 @@ pub struct HttpAuth { pub password: String, } -/// This type is used in registration lock handling. -/// It's identical with HttpAuth, but used to avoid type confusion. -#[derive(Derivative, Clone, Serialize, Deserialize)] -#[derivative(Debug)] -pub struct AuthCredentials { - pub username: String, - #[derivative(Debug = "ignore")] - pub password: String, -} - #[derive(Debug, Clone)] pub enum HttpAuthOverride { NoOverride, @@ -164,127 +141,83 @@ pub struct StaleDevices { pub stale_devices: Vec, } -#[derive(Debug)] -struct RequestBody { - contents: Vec, - content_type: String, -} - #[derive(Clone)] pub struct PushService { cfg: ServiceConfiguration, - user_agent: String, credentials: Option, - client: - Client>, Full>, + client: reqwest::Client, } impl PushService { pub fn new( cfg: impl Into, credentials: Option, - user_agent: String, + user_agent: impl AsRef, ) -> Self { let cfg = cfg.into(); - let tls_config = Self::tls_config(&cfg); - - let https = hyper_rustls::HttpsConnectorBuilder::new() - .with_tls_config(tls_config) - .https_only() - .enable_http1() - .build(); - - // as in Signal-Android - let mut timeout_connector = TimeoutConnector::new(https); - timeout_connector.set_connect_timeout(Some(Duration::from_secs(10))); - timeout_connector.set_read_timeout(Some(Duration::from_secs(65))); - timeout_connector.set_write_timeout(Some(Duration::from_secs(65))); - - let client: Client<_, Full> = - Client::builder(TokioExecutor::new()).build(timeout_connector); + let client = reqwest::ClientBuilder::new() + .add_root_certificate( + reqwest::Certificate::from_pem( + &cfg.certificate_authority.as_bytes(), + ) + .unwrap(), + ) + .connect_timeout(Duration::from_secs(10)) + .timeout(Duration::from_secs(65)) + .user_agent(user_agent.as_ref()) + .build() + .unwrap(); Self { cfg, credentials: credentials.and_then(|c| c.authorization()), client, - user_agent, } } - fn tls_config(cfg: &ServiceConfiguration) -> rustls::ClientConfig { - let mut cert_bytes = io::Cursor::new(&cfg.certificate_authority); - let roots = rustls_pemfile::certs(&mut cert_bytes); - - let mut root_certs = rustls::RootCertStore::empty(); - root_certs.add_parsable_certificates( - roots.map(|c| c.expect("parsable PEM files")), - ); - - rustls::ClientConfig::builder() - .with_root_certificates(root_certs) - .with_no_client_auth() - } - - #[tracing::instrument(skip(self, path, body), fields(path = %path.as_ref()))] - async fn request( + #[tracing::instrument(skip(self, path), fields(path = %path.as_ref()))] + pub fn request( &self, method: Method, endpoint: Endpoint, path: impl AsRef, - additional_headers: &[(&str, &str)], - credentials_override: HttpAuthOverride, - body: Option, - ) -> Result, ServiceError> { + auth_override: HttpAuthOverride, + ) -> Result { let url = self.cfg.base_url(endpoint).join(path.as_ref())?; - let mut builder = Request::builder() - .method(method) - .uri(url.as_str()) - .header(USER_AGENT, &self.user_agent); - - for (header, value) in additional_headers { - builder = builder.header(*header, *value); - } + let mut builder = self.client.request(method, url); - match credentials_override { + builder = match auth_override { HttpAuthOverride::NoOverride => { if let Some(HttpAuth { username, password }) = self.credentials.as_ref() { + builder.basic_auth(username, Some(password)) + } else { builder - .headers_mut() - .unwrap() - .typed_insert(Authorization::basic(username, password)); } }, HttpAuthOverride::Identified(HttpAuth { username, password }) => { - builder - .headers_mut() - .unwrap() - .typed_insert(Authorization::basic(&username, &password)); + builder.basic_auth(username, Some(password)) }, - HttpAuthOverride::Unidentified => (), + HttpAuthOverride::Unidentified => builder, }; - let request = if let Some(RequestBody { - contents, - content_type, - }) = body - { - builder - .header(CONTENT_LENGTH, contents.len() as u64) - .header(CONTENT_TYPE, content_type) - .body(Full::new(Bytes::from(contents))) - .unwrap() - } else { - builder.body(Full::default()).unwrap() - }; + Ok(builder) + } +} - let mut response = self.client.request(request).await.map_err(|e| { - ServiceError::SendError { - reason: e.to_string(), - } - })?; +#[async_trait::async_trait] +pub(crate) trait ReqwestExt +where + Self: Sized, +{ + async fn send_to_signal(self) -> Result; +} +#[async_trait::async_trait] +impl ReqwestExt for RequestBuilder { + async fn send_to_signal(self) -> Result { + let response = self.send().await?; match response.status() { StatusCode::OK => Ok(response), StatusCode::NO_CONTENT => Ok(response), @@ -301,10 +234,10 @@ impl PushService { }, StatusCode::CONFLICT => { let mismatched_devices = - Self::json(&mut response).await.map_err(|e| { + response.json().await.map_err(|error| { tracing::error!( - "Failed to decode HTTP 409 response: {}", - e + %error, + "failed to decode HTTP 409 status" ); ServiceError::UnhandledResponseCode { http_code: StatusCode::CONFLICT.as_u16(), @@ -315,25 +248,17 @@ impl PushService { )) }, StatusCode::GONE => { - let stale_devices = - Self::json(&mut response).await.map_err(|e| { - tracing::error!( - "Failed to decode HTTP 410 response: {}", - e - ); - ServiceError::UnhandledResponseCode { - http_code: StatusCode::GONE.as_u16(), - } - })?; + let stale_devices = response.json().await.map_err(|error| { + tracing::error!(%error, "failed to decode HTTP 410 status"); + ServiceError::UnhandledResponseCode { + http_code: StatusCode::GONE.as_u16(), + } + })?; Err(ServiceError::StaleDevices(stale_devices)) }, StatusCode::LOCKED => { - let locked = Self::json(&mut response).await.map_err(|e| { - tracing::error!( - ?response, - "Failed to decode HTTP 423 response: {}", - e - ); + let locked = response.json().await.map_err(|error| { + tracing::error!(%error, "failed to decode HTTP 423 status"); ServiceError::UnhandledResponseCode { http_code: StatusCode::LOCKED.as_u16(), } @@ -342,10 +267,10 @@ impl PushService { }, StatusCode::PRECONDITION_REQUIRED => { let proof_required = - Self::json(&mut response).await.map_err(|e| { + response.json().await.map_err(|error| { tracing::error!( - "Failed to decode HTTP 428 response: {}", - e + %error, + "failed to decode HTTP 428 status" ); ServiceError::UnhandledResponseCode { http_code: StatusCode::PRECONDITION_REQUIRED @@ -356,306 +281,107 @@ impl PushService { }, // XXX: fill in rest from PushServiceSocket code => { - tracing::trace!( - "Unhandled response {} with body: {}", - code.as_u16(), - Self::text(&mut response).await?, - ); + let response_text = response.text().await?; + tracing::trace!(status_code =% code, body = response_text, "unhandled HTTP response"); Err(ServiceError::UnhandledResponseCode { http_code: code.as_u16(), }) }, } } - - async fn body( - response: &mut Response, - ) -> Result { - Ok(response - .collect() - .await - .map_err(|e| ServiceError::ResponseError { - reason: format!("failed to aggregate HTTP response body: {e}"), - })? - .aggregate()) - } - - #[tracing::instrument(skip(response), fields(status = %response.status()))] - async fn json( - response: &mut Response, - ) -> Result - where - for<'de> T: Deserialize<'de>, - { - let body = Self::body(response).await?; - - if body.has_remaining() { - serde_json::from_reader(body.reader()) - } else { - serde_json::from_value(serde_json::Value::Null) - } - .map_err(|e| ServiceError::JsonDecodeError { - reason: e.to_string(), - }) - } - - #[tracing::instrument(skip(response), fields(status = %response.status()))] - async fn protobuf( - response: &mut Response, - ) -> Result - where - M: ProtobufMessage + Default, - { - let body = Self::body(response).await?; - M::decode(body).map_err(ServiceError::ProtobufDecodeError) - } - - #[tracing::instrument(skip(response), fields(status = %response.status()))] - async fn text( - response: &mut Response, - ) -> Result { - let body = Self::body(response).await?; - io::read_to_string(body.reader()).map_err(|e| { - ServiceError::ResponseError { - reason: format!("failed to read HTTP response body: {e}"), - } - }) - } } -impl PushService { - #[tracing::instrument(skip(self))] - pub(crate) async fn get_json( - &mut self, - service: Endpoint, - path: &str, - additional_headers: &[(&str, &str)], - credentials_override: HttpAuthOverride, - ) -> Result - where - for<'de> T: Deserialize<'de>, - { - let mut response = self - .request( - Method::GET, - service, - path, - additional_headers, - credentials_override, - None, - ) - .await?; +pub(crate) mod protobuf { + use async_trait::async_trait; + use prost::{EncodeError, Message}; + use reqwest::{header, RequestBuilder, Response}; - Self::json(&mut response).await - } + use super::ServiceError; - #[tracing::instrument(skip(self))] - async fn delete_json( - &mut self, - service: Endpoint, - path: &str, - additional_headers: &[(&str, &str)], - ) -> Result + pub(crate) trait ProtobufRequestBuilderExt where - for<'de> T: Deserialize<'de>, + Self: Sized, { - let mut response = self - .request( - Method::DELETE, - service, - path, - additional_headers, - HttpAuthOverride::NoOverride, - None, - ) - .await?; - - Self::json(&mut response).await - } - - #[tracing::instrument(skip(self, value))] - pub async fn put_json( - &mut self, - service: Endpoint, - path: &str, - additional_headers: &[(&str, &str)], - credentials_override: HttpAuthOverride, - value: S, - ) -> Result - where - for<'de> D: Deserialize<'de>, - S: Send + Serialize, - { - let json = serde_json::to_vec(&value).map_err(|e| { - ServiceError::JsonDecodeError { - reason: e.to_string(), - } - })?; - - let mut response = self - .request( - Method::PUT, - service, - path, - additional_headers, - credentials_override, - Some(RequestBody { - contents: json, - content_type: "application/json".into(), - }), - ) - .await?; - - Self::json(&mut response).await + /// Set the request payload encoded as protobuf. + /// Sets the `Content-Type` header to `application/protobuf` + #[allow(dead_code)] + fn protobuf( + self, + value: T, + ) -> Result; } - #[tracing::instrument(skip(self, value))] - async fn patch_json( - &mut self, - service: Endpoint, - path: &str, - additional_headers: &[(&str, &str)], - credentials_override: HttpAuthOverride, - value: S, - ) -> Result - where - for<'de> D: Deserialize<'de>, - S: Send + Serialize, - { - let json = serde_json::to_vec(&value).map_err(|e| { - ServiceError::JsonDecodeError { - reason: e.to_string(), - } - })?; - - let mut response = self - .request( - Method::PATCH, - service, - path, - additional_headers, - credentials_override, - Some(RequestBody { - contents: json, - content_type: "application/json".into(), - }), - ) - .await?; - - Self::json(&mut response).await + #[async_trait::async_trait] + pub(crate) trait ProtobufResponseExt { + /// Get the response body decoded from Protobuf + async fn protobuf( + self, + ) -> Result; } - #[tracing::instrument(skip(self, value))] - async fn post_json( - &mut self, - service: Endpoint, - path: &str, - additional_headers: &[(&str, &str)], - credentials_override: HttpAuthOverride, - value: S, - ) -> Result - where - for<'de> D: Deserialize<'de>, - S: Send + Serialize, - { - let json = serde_json::to_vec(&value).map_err(|e| { - ServiceError::JsonDecodeError { - reason: e.to_string(), - } - })?; - - let mut response = self - .request( - Method::POST, - service, - path, - additional_headers, - credentials_override, - Some(RequestBody { - contents: json, - content_type: "application/json".into(), - }), - ) - .await?; - - Self::json(&mut response).await - } - - #[tracing::instrument(skip(self))] - async fn get_protobuf( - &mut self, - service: Endpoint, - path: &str, - additional_headers: &[(&str, &str)], - credentials_override: HttpAuthOverride, - ) -> Result - where - T: Default + ProtobufMessage, - { - let mut response = self - .request( - Method::GET, - service, - path, - additional_headers, - credentials_override, - None, - ) - .await?; - - Self::protobuf(&mut response).await + impl ProtobufRequestBuilderExt for RequestBuilder { + fn protobuf( + self, + value: T, + ) -> Result { + let mut buf = Vec::new(); + value.encode(&mut buf)?; + let this = + self.header(header::CONTENT_TYPE, "application/protobuf"); + Ok(this.body(buf)) + } } - #[tracing::instrument(skip(self, value))] - async fn put_protobuf( - &mut self, - service: Endpoint, - path: &str, - additional_headers: &[(&str, &str)], - value: S, - ) -> Result - where - D: Default + ProtobufMessage, - S: Sized + ProtobufMessage, - { - let protobuf = value.encode_to_vec(); - - let mut response = self - .request( - Method::PUT, - service, - path, - additional_headers, - HttpAuthOverride::NoOverride, - Some(RequestBody { - contents: protobuf, - content_type: "application/x-protobuf".into(), - }), - ) - .await?; - - Self::protobuf(&mut response).await + #[async_trait] + impl ProtobufResponseExt for Response { + async fn protobuf( + self, + ) -> Result { + let body = self.bytes().await?; + let decoded = T::decode(body)?; + Ok(decoded) + } } +} +impl PushService { pub async fn ws( &mut self, path: &str, keepalive_path: &str, - additional_headers: &[(&str, &str)], + additional_headers: &[(&'static str, &str)], credentials: Option, ) -> Result { let span = debug_span!("websocket"); - let (ws, stream) = TungsteniteWebSocket::with_tls_config( - Self::tls_config(&self.cfg), - self.cfg.base_url(Endpoint::Service), - path, - additional_headers, - credentials.as_ref(), - ) - .instrument(span.clone()) - .await?; + + let endpoint = self.cfg.base_url(Endpoint::Service); + let mut url = endpoint.join(path).expect("valid url"); + url.set_scheme("wss").expect("valid https base url"); + + if let Some(credentials) = credentials { + url.query_pairs_mut() + .append_pair("login", &credentials.login()) + .append_pair( + "password", + credentials.password.as_ref().expect("a password"), + ); + } + + let mut builder = self.client.get(url); + for (key, value) in additional_headers { + builder = builder.header(*key, *value); + } + + let ws = builder + .upgrade() + .send() + .await? + .into_websocket() + .instrument(span.clone()) + .await?; + let (ws, task) = - SignalWebSocket::from_socket(ws, stream, keepalive_path.to_owned()); + SignalWebSocket::from_socket(ws, keepalive_path.to_owned()); let task = task.instrument(span); tokio::task::spawn(task); Ok(ws) @@ -665,13 +391,17 @@ impl PushService { &mut self, credentials: HttpAuth, ) -> Result { - self.get_protobuf( + self.request( + Method::GET, Endpoint::Storage, "/v1/groups/", - &[], HttpAuthOverride::Identified(credentials), - ) + )? + .send() + .await? + .protobuf() .await + .map_err(Into::into) } } diff --git a/src/push_service/profile.rs b/src/push_service/profile.rs index 0a444fef6..b05aef56c 100644 --- a/src/push_service/profile.rs +++ b/src/push_service/profile.rs @@ -1,3 +1,4 @@ +use reqwest::Method; use serde::{Deserialize, Serialize}; use zkgroup::profiles::{ProfileKeyCommitment, ProfileKeyVersion}; @@ -5,12 +6,12 @@ use crate::{ configuration::Endpoint, content::ServiceError, profile_cipher::ProfileCipherError, - push_service::{AvatarWrite, HttpAuthOverride}, + push_service::AvatarWrite, utils::{serde_base64, serde_optional_base64}, Profile, ServiceAddress, }; -use super::{DeviceCapabilities, PushService}; +use super::{DeviceCapabilities, HttpAuthOverride, PushService, ReqwestExt}; #[derive(Debug, Deserialize)] #[serde(rename_all = "camelCase")] @@ -102,13 +103,17 @@ impl PushService { format!("/v1/profile/{}", address.uuid) }; // TODO: set locale to en_US - self.get_json( + self.request( + Method::GET, Endpoint::Service, &endpoint, - &[], HttpAuthOverride::NoOverride, - ) + )? + .send_to_signal() + .await? + .json() .await + .map_err(Into::into) } pub async fn retrieve_profile_avatar( @@ -161,35 +166,32 @@ impl PushService { }; // XXX this should be a struct; cfr ProfileAvatarUploadAttributes - let response: Result = self - .put_json( + let upload_url: Result = self + .request( + Method::PUT, Endpoint::Service, "/v1/profile", - &[], HttpAuthOverride::NoOverride, - command, - ) + )? + .json(&command) + .send_to_signal() + .await? + .json() .await; - match (response, avatar) { - (Ok(_url), AvatarWrite::NewAvatar(_avatar)) => { + + match (upload_url, avatar) { + (_url, AvatarWrite::NewAvatar(_avatar)) => { // FIXME unreachable!("Uploading avatar unimplemented"); }, // FIXME cleanup when #54883 is stable and MSRV: // or-patterns syntax is experimental // see issue #54883 for more information - ( - Err(ServiceError::JsonDecodeError { .. }), - AvatarWrite::RetainAvatar, - ) - | ( - Err(ServiceError::JsonDecodeError { .. }), - AvatarWrite::NoAvatar, - ) => { + (Err(_), AvatarWrite::RetainAvatar) + | (Err(_), AvatarWrite::NoAvatar) => { // OWS sends an empty string when there's no attachment Ok(None) }, - (Err(e), _) => Err(e), (Ok(_resp), AvatarWrite::RetainAvatar) | (Ok(_resp), AvatarWrite::NoAvatar) => { tracing::warn!( diff --git a/src/push_service/registration.rs b/src/push_service/registration.rs index 6d353b962..f968b6add 100644 --- a/src/push_service/registration.rs +++ b/src/push_service/registration.rs @@ -1,15 +1,27 @@ +use derivative::Derivative; use libsignal_protocol::IdentityKey; +use reqwest::Method; use serde::{Deserialize, Serialize}; use uuid::Uuid; -use super::{AccountAttributes, AuthCredentials, PushService, ServiceError}; +use super::{AccountAttributes, PushService, ServiceError}; use crate::{ configuration::Endpoint, pre_keys::{KyberPreKeyEntity, SignedPreKeyEntity}, - push_service::HttpAuthOverride, + push_service::{HttpAuthOverride, ReqwestExt}, utils::serde_base64, }; +/// This type is used in registration lock handling. +/// It's identical with HttpAuth, but used to avoid type confusion. +#[derive(Derivative, Clone, Serialize, Deserialize)] +#[derivative(Debug)] +pub struct AuthCredentials { + pub username: String, + #[derivative(Debug = "ignore")] + pub password: String, +} + #[derive(Debug, Deserialize)] #[serde(rename_all = "camelCase")] pub struct RegistrationLockFailure { @@ -31,21 +43,13 @@ pub struct VerifyAccountResponse { pub number: Option, } -#[derive(Clone, Copy, Debug, Eq, PartialEq)] +#[derive(Clone, Copy, Debug, Eq, PartialEq, Serialize)] +#[serde(rename_all = "snake_case")] pub enum VerificationTransport { Sms, Voice, } -impl VerificationTransport { - pub fn as_str(&self) -> &str { - match self { - Self::Sms => "sms", - Self::Voice => "voice", - } - } -} - #[derive(Clone, Debug)] pub enum RegistrationMethod<'a> { SessionId(&'a str), @@ -148,7 +152,13 @@ impl PushService { device_activation_request: DeviceActivationRequest, } - let req = RegistrationSessionRequestBody { + self.request( + Method::POST, + Endpoint::Service, + "/v1/registration", + HttpAuthOverride::NoOverride, + )? + .json(&RegistrationSessionRequestBody { session_id: registration_method.session_id(), recovery_password: registration_method.recovery_password(), account_attributes, @@ -157,18 +167,12 @@ impl PushService { pni_identity_key: pni_identity_key.serialize().into(), device_activation_request, every_signed_key_valid: true, - }; - - let res: VerifyAccountResponse = self - .post_json( - Endpoint::Service, - "/v1/registration", - &[], - HttpAuthOverride::NoOverride, - req, - ) - .await?; - Ok(res) + }) + .send_to_signal() + .await? + .json() + .await + .map_err(Into::into) } // Equivalent of Java's @@ -190,24 +194,24 @@ impl PushService { push_token_type: Option<&'a str>, } - let req = VerificationSessionMetadataRequestBody { + self.request( + Method::POST, + Endpoint::Service, + "/v1/verification/session", + HttpAuthOverride::Unidentified, + )? + .json(&VerificationSessionMetadataRequestBody { number, push_token_type: push_token.as_ref().map(|_| "fcm"), push_token, mcc, mnc, - }; - - let res: RegistrationSessionMetadataResponse = self - .post_json( - Endpoint::Service, - "/v1/verification/session", - &[], - HttpAuthOverride::Unidentified, - req, - ) - .await?; - Ok(res) + }) + .send_to_signal() + .await? + .json() + .await + .map_err(Into::into) } pub async fn patch_verification_session<'a>( @@ -230,25 +234,25 @@ impl PushService { push_token_type: Option<&'a str>, } - let req = UpdateVerificationSessionRequestBody { + self.request( + Method::PATCH, + Endpoint::Service, + &format!("/v1/verification/session/{}", session_id), + HttpAuthOverride::Unidentified, + )? + .json(&UpdateVerificationSessionRequestBody { captcha, push_token_type: push_token.as_ref().map(|_| "fcm"), push_token, mcc, mnc, push_challenge, - }; - - let res: RegistrationSessionMetadataResponse = self - .patch_json( - Endpoint::Service, - &format!("/v1/verification/session/{}", session_id), - &[], - HttpAuthOverride::Unidentified, - req, - ) - .await?; - Ok(res) + }) + .send_to_signal() + .await? + .json() + .await + .map_err(Into::into) } // Equivalent of Java's @@ -271,20 +275,24 @@ impl PushService { // locale: Option, transport: VerificationTransport, ) -> Result { - let mut req = std::collections::HashMap::new(); - req.insert("transport", transport.as_str()); - req.insert("client", client); + #[derive(Debug, Serialize)] + struct VerificationCodeRequest<'a> { + transport: VerificationTransport, + client: &'a str, + } - let res: RegistrationSessionMetadataResponse = self - .post_json( - Endpoint::Service, - &format!("/v1/verification/session/{}/code", session_id), - &[], - HttpAuthOverride::Unidentified, - req, - ) - .await?; - Ok(res) + self.request( + Method::POST, + Endpoint::Service, + &format!("/v1/verification/session/{}/code", session_id), + HttpAuthOverride::Unidentified, + )? + .json(&VerificationCodeRequest { transport, client }) + .send_to_signal() + .await? + .json() + .await + .map_err(Into::into) } pub async fn submit_verification_code( @@ -292,18 +300,24 @@ impl PushService { session_id: &str, verification_code: &str, ) -> Result { - let mut req = std::collections::HashMap::new(); - req.insert("code", verification_code); + #[derive(Debug, Serialize)] + struct VerificationCode<'a> { + code: &'a str, + } - let res: RegistrationSessionMetadataResponse = self - .put_json( - Endpoint::Service, - &format!("/v1/verification/session/{}/code", session_id), - &[], - HttpAuthOverride::Unidentified, - req, - ) - .await?; - Ok(res) + self.request( + Method::PUT, + Endpoint::Service, + &format!("/v1/verification/session/{}/code", session_id), + HttpAuthOverride::Unidentified, + )? + .json(&VerificationCode { + code: verification_code, + }) + .send_to_signal() + .await? + .json() + .await + .map_err(Into::into) } } diff --git a/src/sender.rs b/src/sender.rs index ae8a862e1..e755d8460 100644 --- a/src/sender.rs +++ b/src/sender.rs @@ -222,9 +222,10 @@ where .get_attachment_v2_upload_attributes() .instrument(tracing::trace_span!("requesting upload attributes")) .await?; + let (id, digest) = self .service - .upload_attachment(&attrs, &mut std::io::Cursor::new(&contents)) + .upload_attachment(attrs, &mut std::io::Cursor::new(&contents)) .instrument(tracing::trace_span!("Uploading attachment")) .await?; diff --git a/src/websocket/mod.rs b/src/websocket/mod.rs index 4736edc45..b74d02ff1 100644 --- a/src/websocket/mod.rs +++ b/src/websocket/mod.rs @@ -9,18 +9,18 @@ use futures::channel::{mpsc, oneshot}; use futures::future::BoxFuture; use futures::prelude::*; use futures::stream::FuturesUnordered; -use prost::Message; +use reqwest_websocket::WebSocket; use serde::{Deserialize, Serialize}; +use tokio::time::Instant; -use crate::messagepipe::{WebSocketService, WebSocketStreamItem}; use crate::proto::{ web_socket_message, WebSocketMessage, WebSocketRequestMessage, WebSocketResponseMessage, }; -use crate::push_service::{MismatchedDevices, ServiceError}; +use crate::push_service::{self, MismatchedDevices, ServiceError}; mod sender; -pub(crate) mod tungstenite; +// pub(crate) mod tungstenite; type RequestStreamItem = ( WebSocketRequestMessage, @@ -61,7 +61,7 @@ struct SignalWebSocketInner { stream: Option, } -struct SignalWebSocketProcess { +struct SignalWebSocketProcess { /// Whether to enable keep-alive or not (and send a request to this path) keep_alive_path: String, @@ -85,16 +85,16 @@ struct SignalWebSocketProcess { >, // WS backend stuff - ws: WS, - stream: WS::Stream, + ws: WebSocket, } -impl SignalWebSocketProcess { +impl SignalWebSocketProcess { async fn process_frame( &mut self, - frame: Bytes, + frame: Vec, ) -> Result<(), ServiceError> { - let msg = WebSocketMessage::decode(frame)?; + use prost::Message; + let msg = WebSocketMessage::decode(Bytes::from(frame))?; if let Some(request) = &msg.request { tracing::trace!( "decoded WebSocketMessage request {{ r#type: {:?}, verb: {:?}, path: {:?}, body: {} bytes, headers: {:?}, id: {:?} }}", @@ -128,7 +128,7 @@ impl SignalWebSocketProcess { let (sink, recv) = oneshot::channel(); tracing::trace!("sending request with body"); self.request_sink.send((request, sink)).await.map_err( - |_| ServiceError::WsError { + |_| ServiceError::WsClosing { reason: "request handler failed".into(), }, )?; @@ -155,10 +155,11 @@ impl SignalWebSocketProcess { } else if let Some(_x) = self.outgoing_keep_alive_set.take(&id) { - if response.status() != 200 { + let status_code = response.status(); + if status_code != 200 { tracing::warn!( - "Response code for keep-alive is not 200: {:?}", - response + status_code, + "response code for keep-alive != 200" ); return Err(ServiceError::UnhandledResponseCode { http_code: response.status() as u16, @@ -166,8 +167,8 @@ impl SignalWebSocketProcess { } } else { tracing::warn!( - "Response for non existing request: {:?}", - response + ?response, + "response for non existing request" ); } } @@ -193,12 +194,40 @@ impl SignalWebSocketProcess { } async fn run(mut self) -> Result<(), ServiceError> { - loop { + let mut ka_interval = tokio::time::interval_at( + Instant::now(), + push_service::KEEPALIVE_TIMEOUT_SECONDS, + ); + + Ok(loop { futures::select! { + _ = ka_interval.tick().fuse() => { + use prost::Message; + tracing::debug!("sending keep-alive"); + let request = WebSocketRequestMessage { + id: Some(self.next_request_id()), + path: Some(self.keep_alive_path.clone()), + verb: Some("GET".into()), + ..Default::default() + }; + self.outgoing_keep_alive_set.insert(request.id.unwrap()); + let msg = WebSocketMessage { + r#type: Some(web_socket_message::Type::Request.into()), + request: Some(request), + ..Default::default() + }; + let buffer = msg.encode_to_vec(); + if let Err(e) = self.ws.send(reqwest_websocket::Message::Binary(buffer)).await { + tracing::info!("Websocket sink has closed: {:?}.", e); + break; + }; + }, // Process requests from the application, forward them to Signal x = self.requests.next() => { match x { Some((mut request, responder)) => { + use prost::Message; + // Regenerate ID if already in the table request.id = Some( request @@ -222,47 +251,44 @@ impl SignalWebSocketProcess { ..Default::default() }; let buffer = msg.encode_to_vec(); - self.ws.send_message(buffer.into()).await? + self.ws.send(reqwest_websocket::Message::Binary(buffer)).await? } None => { - return Err(ServiceError::WsError { - reason: "SignalWebSocket: end of application request stream; socket closing".into() + return Err(ServiceError::WsClosing { + reason: "SignalWebSocket: end of application request stream; socket closing" }); } } } - web_socket_item = self.stream.next() => { + web_socket_item = self.ws.next().fuse() => { + use reqwest_websocket::Message; match web_socket_item { - Some(WebSocketStreamItem::Message(frame)) => { + Some(Ok(Message::Close { code, reason })) => { + tracing::warn!(%code, reason, "websocket closed"); + break; + }, + Some(Ok(Message::Binary(frame))) => { self.process_frame(frame).await?; } - Some(WebSocketStreamItem::KeepAliveRequest) => { - // XXX: would be nicer if we could drop this request into the request - // queue above. - tracing::debug!("Sending keep alive upon request"); - let request = WebSocketRequestMessage { - id: Some(self.next_request_id()), - path: Some(self.keep_alive_path.clone()), - verb: Some("GET".into()), - ..Default::default() - }; - self.outgoing_keep_alive_set.insert(request.id.unwrap()); - let msg = WebSocketMessage { - r#type: Some(web_socket_message::Type::Request.into()), - request: Some(request), - ..Default::default() - }; - let buffer = msg.encode_to_vec(); - self.ws.send_message(buffer.into()).await?; + Some(Ok(Message::Ping(_))) => { + tracing::trace!("received ping"); } + Some(Ok(Message::Pong(_))) => { + tracing::trace!("received pong"); + } + Some(Ok(Message::Text(_))) => { + tracing::trace!("received text (unsupported, skipping)"); + } + Some(Err(e)) => return Err(ServiceError::WsError(e)), None => { - return Err(ServiceError::WsError { - reason: "end of web request stream; socket closing".into() + return Err(ServiceError::WsClosing { + reason: "end of web request stream; socket closing" }); } } } response = self.outgoing_responses.next() => { + use prost::Message; match response { Some(Ok(response)) => { tracing::trace!("sending response {:?}", response); @@ -273,7 +299,7 @@ impl SignalWebSocketProcess { ..Default::default() }; let buffer = msg.encode_to_vec(); - self.ws.send_message(buffer.into()).await?; + self.ws.send(buffer.into()).await?; } Some(Err(e)) => { tracing::error!("could not generate response to a Signal request; responder was canceled: {}. Continuing.", e); @@ -284,7 +310,7 @@ impl SignalWebSocketProcess { } } } - } + }) } } @@ -293,9 +319,8 @@ impl SignalWebSocket { self.inner.lock().unwrap() } - pub fn from_socket( - ws: WS, - stream: WS::Stream, + pub fn from_socket( + ws: WebSocket, keep_alive_path: String, ) -> (Self, impl Future) { // Create process @@ -316,7 +341,6 @@ impl SignalWebSocket { .into_iter() .collect(), ws, - stream, }; let process = process.run().map(|x| match x { Ok(()) => (), @@ -389,15 +413,14 @@ impl SignalWebSocket { async move { if let Err(_e) = request_sink.send((r, sink)).await { return Err(ServiceError::WsClosing { - reason: "WebSocket closing while sending request.".into(), + reason: "WebSocket closing while sending request.", }); } // Handle the oneshot sender error for dropped senders. match recv.await { Ok(x) => x, Err(_) => Err(ServiceError::WsClosing { - reason: "WebSocket closing while waiting for a response." - .into(), + reason: "WebSocket closing while waiting for a response.", }), } } diff --git a/src/websocket/tungstenite.rs b/src/websocket/tungstenite.rs index ae40caf09..5c2b56673 100644 --- a/src/websocket/tungstenite.rs +++ b/src/websocket/tungstenite.rs @@ -127,10 +127,7 @@ impl TungsteniteWebSocket { path: &str, additional_headers: &[(&str, &str)], credentials: Option<&ServiceCredentials>, - ) -> Result< - (Self, ::Stream), - TungsteniteWebSocketError, - > { + ) -> Result { let mut url = base_url.borrow().join(path).expect("valid url"); url.set_scheme("wss").expect("valid https base url"); @@ -188,11 +185,11 @@ impl TungsteniteWebSocket { } } -#[async_trait::async_trait] -impl WebSocketService for TungsteniteWebSocket { - type Stream = Receiver; - - async fn send_message(&mut self, msg: Bytes) -> Result<(), ServiceError> { +impl TungsteniteWebSocket { + pub async fn send_message( + &mut self, + msg: Bytes, + ) -> Result<(), ServiceError> { self.socket_sink .send(Message::Binary(msg.to_vec())) .await