From 9782968a7242b6cf0234cfbddf7a625857b45e51 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20F=C3=A9ron?= Date: Fri, 18 Oct 2024 23:20:44 +0200 Subject: [PATCH] Switch to reqwest (#332) Also: * Switch to reqwest and reqwest-websocket * Use PushService instead of websocket like in libsignal.kt * Introduce trait to share error handling between PushService and WebSocketService --- Cargo.toml | 16 +- src/account_manager.rs | 83 +++-- src/cipher.rs | 26 +- src/envelope.rs | 4 +- src/groups_v2/manager.rs | 32 +- src/messagepipe.rs | 33 +- src/profile_service.rs | 40 +- src/provisioning/mod.rs | 4 +- src/provisioning/pipe.rs | 16 +- src/push_service/account.rs | 55 ++- src/push_service/cdn.rs | 224 +++++------- src/push_service/error.rs | 28 +- src/push_service/keys.rs | 129 ++++--- src/push_service/linking.rs | 33 +- src/push_service/mod.rs | 602 +++++++------------------------ src/push_service/profile.rs | 48 +-- src/push_service/registration.rs | 176 +++++---- src/push_service/response.rs | 160 ++++++++ src/receiver.rs | 2 +- src/sender.rs | 21 +- src/websocket/mod.rs | 303 ++++++---------- src/websocket/request.rs | 52 +++ src/websocket/sender.rs | 21 +- src/websocket/tungstenite.rs | 202 ----------- 24 files changed, 960 insertions(+), 1350 deletions(-) create mode 100644 src/push_service/response.rs create mode 100644 src/websocket/request.rs delete mode 100644 src/websocket/tungstenite.rs diff --git a/Cargo.toml b/Cargo.toml index c3213bb4e..cc43f204a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -35,22 +35,14 @@ url = { version = "2.1", features = ["serde"] } uuid = { version = "1", features = ["serde"] } # http -hyper = "1.0" -hyper-util = { version = "0.1", features = ["client", "client-legacy"] } -hyper-rustls = { version = "0.27", default-features = false, features = ["http1", "http2", "ring", "logging"] } -hyper-timeout = "0.5" -headers = "0.4" -http-body-util = "0.1" -mpart-async = "0.7" -async-tungstenite = { version = "0.27", features = ["tokio-rustls-native-certs", "url"] } -tokio = { version = "1.0", features = ["macros"] } -tokio-rustls = { version = "0.26", default-features = false, features = ["logging", "ring"] } - -rustls-pemfile = "2.0" +reqwest = { version = "0.12", default-features = false, features = ["json", "multipart", "rustls-tls-manual-roots", "stream"] } +reqwest-websocket = { version = "0.4.2", features = ["json"] } tracing = { version = "0.1", features = ["log"] } tracing-futures = "0.2" +tokio = { version = "1.0", features = ["macros"] } + [build-dependencies] prost-build = "0.13" diff --git a/src/account_manager.rs b/src/account_manager.rs index d8251c92b..9e69aaccb 100644 --- a/src/account_manager.rs +++ b/src/account_manager.rs @@ -1,5 +1,6 @@ use base64::prelude::*; use phonenumber::PhoneNumber; +use reqwest::Method; use std::collections::HashMap; use std::convert::{TryFrom, TryInto}; @@ -28,9 +29,9 @@ use crate::proto::sync_message::PniChangeNumber; use crate::proto::{DeviceName, SyncMessage}; use crate::provisioning::generate_registration_id; use crate::push_service::{ - AvatarWrite, DeviceActivationRequest, DeviceInfo, RecaptchaAttributes, - RegistrationMethod, ServiceIdType, VerifyAccountResponse, - DEFAULT_DEVICE_ID, + AvatarWrite, DeviceActivationRequest, DeviceInfo, HttpAuthOverride, + RecaptchaAttributes, RegistrationMethod, ReqwestExt, ServiceIdType, + VerifyAccountResponse, DEFAULT_DEVICE_ID, }; use crate::sender::OutgoingPushMessage; use crate::session_store::SessionStoreExt; @@ -44,9 +45,7 @@ use crate::{ profile_name::ProfileName, proto::{ProvisionEnvelope, ProvisionMessage, ProvisioningVersion}, provisioning::{ProvisioningCipher, ProvisioningError}, - push_service::{ - AccountAttributes, HttpAuthOverride, PushService, ServiceError, - }, + push_service::{AccountAttributes, PushService, ServiceError}, utils::serde_base64, }; @@ -224,13 +223,19 @@ impl AccountManager { let dc: DeviceCode = self .service - .get_json( + .request( + Method::GET, Endpoint::Service, "/v1/devices/provisioning/code", - &[], HttpAuthOverride::NoOverride, - ) + )? + .send() + .await? + .service_error_for_status() + .await? + .json() .await?; + Ok(dc.verification_code) } @@ -247,16 +252,21 @@ impl AccountManager { let body = env.encode_to_vec(); self.service - .put_json( + .request( + Method::PUT, Endpoint::Service, - &format!("/v1/provisioning/{}", destination), - &[], + format!("/v1/provisioning/{}", destination), HttpAuthOverride::NoOverride, - &ProvisioningMessage { - body: BASE64_RELAXED.encode(body), - }, - ) - .await + )? + .json(&ProvisioningMessage { + body: BASE64_RELAXED.encode(body), + }) + .send() + .await? + .service_error_for_status() + .await?; + + Ok(()) } /// Link a new device, given a tsurl. @@ -582,15 +592,18 @@ impl AccountManager { } self.service - .put_json::<(), _>( + .request( + Method::PUT, Endpoint::Service, "/v1/accounts/name", - &[], HttpAuthOverride::NoOverride, - Data { - device_name: encrypted_device_name.encode_to_vec(), - }, - ) + )? + .json(&Data { + device_name: encrypted_device_name.encode_to_vec(), + }) + .send() + .await? + .service_error_for_status() .await?; Ok(()) @@ -607,20 +620,24 @@ impl AccountManager { token: &str, captcha: &str, ) -> Result<(), ServiceError> { - let payload = RecaptchaAttributes { - r#type: String::from("recaptcha"), - token: String::from(token), - captcha: String::from(captcha), - }; self.service - .put_json( + .request( + Method::PUT, Endpoint::Service, "/v1/challenge", - &[], HttpAuthOverride::NoOverride, - payload, - ) - .await + )? + .json(&RecaptchaAttributes { + r#type: String::from("recaptcha"), + token: String::from(token), + captcha: String::from(captcha), + }) + .send() + .await? + .service_error_for_status() + .await?; + + Ok(()) } /// Initialize PNI on linked devices. diff --git a/src/cipher.rs b/src/cipher.rs index 09649b679..057cc8332 100644 --- a/src/cipher.rs +++ b/src/cipher.rs @@ -134,10 +134,9 @@ where let ciphertext = if let Some(msg) = envelope.content.as_ref() { msg } else { - return Err(ServiceError::InvalidFrameError { + return Err(ServiceError::InvalidFrame { reason: - "Envelope should have either a legacy message or content." - .into(), + "envelope should have either a legacy message or content.", }); }; @@ -311,11 +310,8 @@ where }, _ => { // else - return Err(ServiceError::InvalidFrameError { - reason: format!( - "Envelope has unknown type {:?}.", - envelope.r#type() - ), + return Err(ServiceError::InvalidFrame { + reason: "envelope has unknown type", }); }, }; @@ -408,9 +404,7 @@ struct Plaintext { #[allow(clippy::comparison_chain)] fn add_padding(version: u32, contents: &[u8]) -> Result, ServiceError> { if version < 2 { - Err(ServiceError::InvalidFrameError { - reason: format!("Unknown version {}", version), - }) + Err(ServiceError::PaddingVersion(version)) } else if version == 2 { Ok(contents.to_vec()) } else { @@ -436,8 +430,8 @@ fn strip_padding_version( contents: &mut Vec, ) -> Result<(), ServiceError> { if version < 2 { - Err(ServiceError::InvalidFrameError { - reason: format!("Unknown version {}", version), + Err(ServiceError::InvalidFrame { + reason: "unknown version", }) } else if version == 2 { Ok(()) @@ -449,11 +443,7 @@ fn strip_padding_version( #[allow(clippy::comparison_chain)] fn strip_padding(contents: &mut Vec) -> Result<(), ServiceError> { - let new_length = Iso7816::raw_unpad(contents) - .map_err(|e| ServiceError::InvalidFrameError { - reason: format!("Invalid message padding: {:?}", e), - })? - .len(); + let new_length = Iso7816::raw_unpad(contents)?.len(); contents.resize(new_length, 0); Ok(()) } diff --git a/src/envelope.rs b/src/envelope.rs index 1a33de669..ff247c9a3 100644 --- a/src/envelope.rs +++ b/src/envelope.rs @@ -42,8 +42,8 @@ impl Envelope { if input.len() < VERSION_LENGTH || input[VERSION_OFFSET] != SUPPORTED_VERSION { - return Err(ServiceError::InvalidFrameError { - reason: "Unsupported signaling cryptogram version".into(), + return Err(ServiceError::InvalidFrame { + reason: "unsupported signaling cryptogram version", }); } diff --git a/src/groups_v2/manager.rs b/src/groups_v2/manager.rs index d5e77dffd..a1c7af0be 100644 --- a/src/groups_v2/manager.rs +++ b/src/groups_v2/manager.rs @@ -2,11 +2,13 @@ use std::{collections::HashMap, convert::TryInto}; use crate::{ configuration::Endpoint, - groups_v2::model::{Group, GroupChanges}, - groups_v2::operations::{GroupDecodingError, GroupOperations}, + groups_v2::{ + model::{Group, GroupChanges}, + operations::{GroupDecodingError, GroupOperations}, + }, prelude::{PushService, ServiceError}, proto::GroupContextV2, - push_service::{HttpAuth, HttpAuthOverride, ServiceIds}, + push_service::{HttpAuth, HttpAuthOverride, ReqwestExt, ServiceIds}, utils::BASE64_RELAXED, }; @@ -15,6 +17,7 @@ use bytes::Bytes; use chrono::{Days, NaiveDate, NaiveTime, Utc}; use futures::AsyncReadExt; use rand::RngCore; +use reqwest::Method; use serde::Deserialize; use zkgroup::{ auth::AuthCredentialWithPniResponse, @@ -165,20 +168,24 @@ impl GroupsManager { let credentials_response: CredentialResponse = self .push_service - .get_json( + .request( + Method::GET, Endpoint::Service, &path, - &[], HttpAuthOverride::NoOverride, - ) + )? + .send() + .await? + .service_error_for_status() + .await? + .json() .await?; self.credentials_cache .write(credentials_response.parse()?)?; self.credentials_cache.get(&today)?.ok_or_else(|| { - ServiceError::ResponseError { + ServiceError::InvalidFrame { reason: - "credentials received did not contain requested day" - .into(), + "credentials received did not contain requested day", } })? }; @@ -279,12 +286,7 @@ impl GroupsManager { .retrieve_groups_v2_profile_avatar(path) .await?; let mut result = Vec::with_capacity(10 * 1024 * 1024); - encrypted_avatar - .read_to_end(&mut result) - .await - .map_err(|e| ServiceError::ResponseError { - reason: format!("reading avatar data: {}", e), - })?; + encrypted_avatar.read_to_end(&mut result).await?; Ok(GroupOperations::new(group_secret_params).decrypt_avatar(&result)) } diff --git a/src/messagepipe.rs b/src/messagepipe.rs index cdb77b7a8..9d4e89efa 100644 --- a/src/messagepipe.rs +++ b/src/messagepipe.rs @@ -1,11 +1,9 @@ -use bytes::Bytes; use futures::{ channel::{ mpsc::{self, Sender}, oneshot, }, prelude::*, - stream::FusedStream, }; pub use crate::{ @@ -18,24 +16,12 @@ pub use crate::{ use crate::{push_service::ServiceError, websocket::SignalWebSocket}; -pub enum WebSocketStreamItem { - Message(Bytes), - KeepAliveRequest, -} - #[derive(Debug)] pub enum Incoming { Envelope(Envelope), QueueEmpty, } -#[async_trait::async_trait] -pub trait WebSocketService { - type Stream: FusedStream + Unpin; - - async fn send_message(&mut self, msg: Bytes) -> Result<(), ServiceError>; -} - pub struct MessagePipe { ws: SignalWebSocket, credentials: ServiceCredentials, @@ -93,8 +79,8 @@ impl MessagePipe { let body = if let Some(body) = request.body.as_ref() { body } else { - return Err(ServiceError::InvalidFrameError { - reason: "Request without body.".into(), + return Err(ServiceError::InvalidFrame { + reason: "request without body.", }); }; Some(Incoming::Envelope(Envelope::decrypt( @@ -111,7 +97,7 @@ impl MessagePipe { responder .send(response) .map_err(|_| ServiceError::WsClosing { - reason: "could not respond to message pipe request".into(), + reason: "could not respond to message pipe request", })?; Ok(result) @@ -133,16 +119,3 @@ impl MessagePipe { combined.filter_map(|x| async { x }) } } - -/// WebSocketService that panics on every request, mainly for example code. -pub struct PanicingWebSocketService; - -#[allow(clippy::diverging_sub_expression)] -#[async_trait::async_trait] -impl WebSocketService for PanicingWebSocketService { - type Stream = futures::channel::mpsc::Receiver; - - async fn send_message(&mut self, _msg: Bytes) -> Result<(), ServiceError> { - todo!(); - } -} diff --git a/src/profile_service.rs b/src/profile_service.rs index e23a4b51b..016658876 100644 --- a/src/profile_service.rs +++ b/src/profile_service.rs @@ -1,17 +1,21 @@ +use reqwest::Method; + use crate::{ - proto::WebSocketRequestMessage, - push_service::{ServiceError, SignalServiceProfile}, - websocket::SignalWebSocket, + configuration::Endpoint, + prelude::PushService, + push_service::{ + HttpAuthOverride, ReqwestExt, ServiceError, SignalServiceProfile, + }, ServiceAddress, }; pub struct ProfileService { - ws: SignalWebSocket, + push_service: PushService, } impl ProfileService { - pub fn from_socket(ws: SignalWebSocket) -> Self { - ProfileService { ws } + pub fn from_socket(push_service: PushService) -> Self { + ProfileService { push_service } } pub async fn retrieve_profile_by_id( @@ -19,7 +23,7 @@ impl ProfileService { address: ServiceAddress, profile_key: Option, ) -> Result { - let endpoint = match profile_key { + let path = match profile_key { Some(key) => { let version = bincode::serialize(&key.get_profile_key_version( @@ -34,13 +38,19 @@ impl ProfileService { }, }; - let request = WebSocketRequestMessage { - path: Some(endpoint), - verb: Some("GET".into()), - // TODO: set locale to en_US - ..Default::default() - }; - - self.ws.request_json(request).await + self.push_service + .request( + Method::GET, + Endpoint::Service, + path, + HttpAuthOverride::NoOverride, + )? + .send() + .await? + .service_error_for_status() + .await? + .json() + .await + .map_err(Into::into) } } diff --git a/src/provisioning/mod.rs b/src/provisioning/mod.rs index 2455409a1..f1c16c589 100644 --- a/src/provisioning/mod.rs +++ b/src/provisioning/mod.rs @@ -75,8 +75,8 @@ pub enum ProvisioningError { DecodeError(#[from] prost::DecodeError), #[error("Websocket error: {reason}")] WsError { reason: String }, - #[error("Websocket closing: {reason}")] - WsClosing { reason: String }, + #[error("Websocket closing")] + WsClosing, #[error("Service error: {0}")] ServiceError(#[from] ServiceError), #[error("libsignal-protocol error: {0}")] diff --git a/src/provisioning/pipe.rs b/src/provisioning/pipe.rs index 7aa258b9f..0842ec895 100644 --- a/src/provisioning/pipe.rs +++ b/src/provisioning/pipe.rs @@ -105,11 +105,9 @@ impl ProvisioningPipe { ); // acknowledge - responder.send(ok).map_err(|_| { - ProvisioningError::WsClosing { - reason: "could not respond to provision request".into(), - } - })?; + responder + .send(ok) + .map_err(|_| ProvisioningError::WsClosing)?; Ok(Some(ProvisioningStep::Url(provisioning_url))) }, @@ -132,11 +130,9 @@ impl ProvisioningPipe { self.provisioning_cipher.decrypt(provision_envelope)?; // acknowledge - responder.send(ok).map_err(|_| { - ProvisioningError::WsClosing { - reason: "could not respond to provision request".into(), - } - })?; + responder + .send(ok) + .map_err(|_| ProvisioningError::WsClosing)?; Ok(Some(ProvisioningStep::Message(provision_message))) }, diff --git a/src/push_service/account.rs b/src/push_service/account.rs index d36eff544..0cc99ea0b 100644 --- a/src/push_service/account.rs +++ b/src/push_service/account.rs @@ -2,10 +2,13 @@ use std::fmt; use chrono::{DateTime, Utc}; use phonenumber::PhoneNumber; +use reqwest::Method; use serde::{Deserialize, Serialize}; use uuid::Uuid; -use super::{HttpAuthOverride, PushService, ServiceError}; +use super::{ + response::ReqwestExt, HttpAuthOverride, PushService, ServiceError, +}; use crate::{ configuration::Endpoint, utils::{serde_optional_base64, serde_phone_number}, @@ -127,13 +130,19 @@ pub struct WhoAmIResponse { impl PushService { /// Method used to check our own UUID pub async fn whoami(&mut self) -> Result { - self.get_json( + self.request( + Method::GET, Endpoint::Service, "/v1/accounts/whoami", - &[], HttpAuthOverride::NoOverride, - ) + )? + .send() + .await? + .service_error_for_status() + .await? + .json() .await + .map_err(Into::into) } /// Fetches a list of all devices tied to the authenticated account. @@ -146,12 +155,17 @@ impl PushService { } let devices: DeviceInfoList = self - .get_json( + .request( + Method::GET, Endpoint::Service, "/v1/devices/", - &[], HttpAuthOverride::NoOverride, - ) + )? + .send() + .await? + .service_error_for_status() + .await? + .json() .await?; Ok(devices.devices) @@ -166,18 +180,19 @@ impl PushService { "only one of PIN and registration lock can be set." ); - match self - .put_json( - Endpoint::Service, - "/v1/accounts/attributes/", - &[], - HttpAuthOverride::NoOverride, - attributes, - ) - .await - { - Err(ServiceError::JsonDecodeError { .. }) => Ok(()), - r => r, - } + self.request( + Method::PUT, + Endpoint::Service, + "/v1/accounts/attributes/", + HttpAuthOverride::NoOverride, + )? + .json(&attributes) + .send() + .await? + .service_error_for_status() + .await? + .json() + .await + .map_err(Into::into) } } diff --git a/src/push_service/cdn.rs b/src/push_service/cdn.rs index 3ba962a6d..b87e3e8a3 100644 --- a/src/push_service/cdn.rs +++ b/src/push_service/cdn.rs @@ -1,19 +1,15 @@ use std::io::{self, Read}; -use bytes::Bytes; -use futures::{FutureExt, StreamExt, TryStreamExt}; -use http_body_util::BodyExt; -use hyper::Method; +use futures::TryStreamExt; +use reqwest::{multipart::Part, Method}; use tracing::debug; use crate::{ - configuration::Endpoint, - prelude::AttachmentIdentifier, - proto::AttachmentPointer, - push_service::{HttpAuthOverride, RequestBody}, + configuration::Endpoint, prelude::AttachmentIdentifier, + proto::AttachmentPointer, push_service::HttpAuthOverride, }; -use super::{PushService, ServiceError}; +use super::{response::ReqwestExt, PushService, ServiceError}; #[derive(Debug, serde::Deserialize, Default)] #[serde(rename_all = "camelCase")] @@ -32,170 +28,128 @@ pub struct AttachmentV2UploadAttributes { } impl PushService { + pub async fn get_attachment( + &mut self, + ptr: &AttachmentPointer, + ) -> Result { + let path = match ptr.attachment_identifier.as_ref() { + Some(AttachmentIdentifier::CdnId(id)) => { + format!("attachments/{}", id) + }, + Some(AttachmentIdentifier::CdnKey(key)) => { + format!("attachments/{}", key) + }, + None => { + return Err(ServiceError::InvalidFrame { + reason: "no attachment identifier in pointer", + }); + }, + }; + self.get_from_cdn(ptr.cdn_number(), &path).await + } + #[tracing::instrument(skip(self))] pub(crate) async fn get_from_cdn( &mut self, cdn_id: u32, path: &str, ) -> Result { - let response = self + let response_stream = self .request( Method::GET, Endpoint::Cdn(cdn_id), path, - &[], HttpAuthOverride::Unidentified, // CDN requests are always without authentication - None, - ) - .await?; - - Ok(Box::new( - response - .into_body() - .into_data_stream() - .map_err(|e| io::Error::new(io::ErrorKind::Other, e)) - .into_async_read(), - )) + )? + .send() + .await? + .service_error_for_status() + .await? + .bytes_stream() + .map_err(|e| io::Error::new(io::ErrorKind::Other, e)) + .into_async_read(); + + Ok(response_stream) } - pub async fn get_attachment_by_id( - &mut self, - id: &str, - cdn_id: u32, - ) -> Result { - let path = format!("attachments/{}", id); - self.get_from_cdn(cdn_id, &path).await - } - - pub async fn get_attachment( - &mut self, - ptr: &AttachmentPointer, - ) -> Result { - match ptr.attachment_identifier.as_ref().unwrap() { - AttachmentIdentifier::CdnId(id) => { - // cdn_number did not exist for this part of the protocol. - // cdn_number(), however, returns 0 when the field does not - // exist. - self.get_attachment_by_id(&format!("{}", id), ptr.cdn_number()) - .await - }, - AttachmentIdentifier::CdnKey(key) => { - self.get_attachment_by_id(key, ptr.cdn_number()).await - }, - } - } - - pub async fn get_attachment_v2_upload_attributes( + pub(crate) async fn get_attachment_v2_upload_attributes( &mut self, ) -> Result { - self.get_json( + self.request( + Method::GET, Endpoint::Service, "/v2/attachments/form/upload", - &[], HttpAuthOverride::NoOverride, - ) + )? + .send() + .await? + .service_error_for_status() + .await? + .json() .await + .map_err(Into::into) } /// Upload attachment to CDN /// /// Returns attachment ID and the attachment digest - pub async fn upload_attachment<'s, C>( + pub async fn upload_attachment( &mut self, - attrs: &AttachmentV2UploadAttributes, - content: &'s mut C, - ) -> Result<(u64, Vec), ServiceError> - where - C: std::io::Read + Send + 's, - { - let values = [ - ("acl", &attrs.acl as &str), - ("key", &attrs.key), - ("policy", &attrs.policy), - ("Content-Type", "application/octet-stream"), - ("x-amz-algorithm", &attrs.algorithm), - ("x-amz-credential", &attrs.credential), - ("x-amz-date", &attrs.date), - ("x-amz-signature", &attrs.signature), - ]; - - let mut digester = crate::digeststream::DigestingReader::new(content); - - self.post_to_cdn0( - "attachments/", - &values, - Some(("file", &mut digester)), - ) - .await?; - Ok((attrs.attachment_id, digester.finalize())) + attrs: AttachmentV2UploadAttributes, + mut reader: impl Read + Send, + ) -> Result<(u64, Vec), ServiceError> { + let attachment_id = attrs.attachment_id; + let mut digester = + crate::digeststream::DigestingReader::new(&mut reader); + + self.post_to_cdn0("attachments/", attrs, "file".into(), &mut digester) + .await?; + + Ok((attachment_id, digester.finalize())) } - #[tracing::instrument(skip(self, value, file), fields(file = file.as_ref().map(|_| "")))] - pub async fn post_to_cdn0<'s, C>( + #[tracing::instrument(skip(self, upload_attributes, reader))] + pub async fn post_to_cdn0( &mut self, path: &str, - value: &[(&str, &str)], - file: Option<(&str, &'s mut C)>, - ) -> Result<(), ServiceError> - where - C: Read + Send + 's, - { - let mut form = mpart_async::client::MultipartRequest::default(); - - // mpart-async has a peculiar ordering of the form items, - // and Amazon S3 expects them in a very specific order (i.e., the file contents should - // go last. - // - // mpart-async uses a VecDeque internally for ordering the fields in the order given. - // - // https://github.com/cetra3/mpart-async/issues/16 - - for &(k, v) in value { - form.add_field(k, v); - } - - if let Some((filename, file)) = file { - // XXX Actix doesn't cope with none-'static lifetimes - // https://docs.rs/actix-web/3.2.0/actix_web/body/enum.Body.html - let mut buf = Vec::new(); - file.read_to_end(&mut buf) - .expect("infallible Read instance"); - form.add_stream( + upload_attributes: AttachmentV2UploadAttributes, + filename: String, + mut reader: impl Read + Send, + ) -> Result<(), ServiceError> { + let mut buf = Vec::new(); + reader + .read_to_end(&mut buf) + .expect("infallible Read instance"); + + // Amazon S3 expects multipart fields in a very specific order + // DO NOT CHANGE THIS (or do it, but feel the wrath of the gods) + let form = reqwest::multipart::Form::new() + .text("acl", upload_attributes.acl) + .text("key", upload_attributes.key) + .text("policy", upload_attributes.policy) + .text("Content-Type", "application/octet-stream") + .text("x-amz-algorithm", upload_attributes.algorithm) + .text("x-amz-credential", upload_attributes.credential) + .text("x-amz-date", upload_attributes.date) + .text("x-amz-signature", upload_attributes.signature) + .part( "file", - filename, - "application/octet-stream", - futures::future::ok::<_, ()>(Bytes::from(buf)).into_stream(), + Part::stream(buf) + .mime_str("application/octet-stream")? + .file_name(filename), ); - } - - let content_type = - format!("multipart/form-data; boundary={}", form.get_boundary()); - - // XXX Amazon S3 needs the Content-Length, but we don't know it without depleting the whole - // stream. Sadly, Content-Length != contents.len(), but should include the whole form. - let mut body_contents = vec![]; - while let Some(b) = form.next().await { - // Unwrap, because no error type was used above - body_contents.extend(b.unwrap()); - } - tracing::trace!( - "Sending PUT with Content-Type={} and length {}", - content_type, - body_contents.len() - ); let response = self .request( Method::POST, Endpoint::Cdn(0), path, - &[], HttpAuthOverride::NoOverride, - Some(RequestBody { - contents: body_contents, - content_type, - }), - ) + )? + .multipart(form) + .send() + .await? + .service_error_for_status() .await?; debug!("HyperPushService::PUT response: {:?}", response); diff --git a/src/push_service/error.rs b/src/push_service/error.rs index d197667d9..f10f67866 100644 --- a/src/push_service/error.rs +++ b/src/push_service/error.rs @@ -1,3 +1,4 @@ +use aes::cipher::block_padding::UnpadError; use libsignal_protocol::SignalProtocolError; use zkgroup::ZkGroupDeserializationFailure; @@ -10,7 +11,7 @@ use super::{ #[derive(thiserror::Error, Debug)] pub enum ServiceError { #[error("Service request timed out: {reason}")] - Timeout { reason: String }, + Timeout { reason: &'static str }, #[error("invalid URL: {0}")] InvalidUrl(#[from] url::ParseError), @@ -18,11 +19,11 @@ pub enum ServiceError { #[error("Error sending request: {reason}")] SendError { reason: String }, - #[error("Error decoding response: {reason}")] - ResponseError { reason: String }, + #[error("i/o error")] + IO(#[from] std::io::Error), - #[error("Error decoding JSON response: {reason}")] - JsonDecodeError { reason: String }, + #[error("Error decoding JSON: {0}")] + JsonDecodeError(#[from] serde_json::Error), #[error("Error decoding protobuf frame: {0}")] ProtobufDecodeError(#[from] prost::DecodeError), #[error("error encoding or decoding bincode: {0}")] @@ -39,13 +40,19 @@ pub enum ServiceError { #[error("Unexpected response: HTTP {http_code}")] UnhandledResponseCode { http_code: u16 }, - #[error("Websocket error: {reason}")] - WsError { reason: String }, + #[error("Websocket error: {0}")] + WsError(#[from] reqwest_websocket::Error), #[error("Websocket closing: {reason}")] - WsClosing { reason: String }, + WsClosing { reason: &'static str }, + + #[error("Invalid padding: {0}")] + Padding(#[from] UnpadError), + + #[error("unknown padding version {0}")] + PaddingVersion(u32), #[error("Invalid frame: {reason}")] - InvalidFrameError { reason: String }, + InvalidFrame { reason: &'static str }, #[error("MAC error")] MacError, @@ -85,4 +92,7 @@ pub enum ServiceError { #[error("invalid device name")] InvalidDeviceName, + + #[error("HTTP reqwest error: {0}")] + Http(#[from] reqwest::Error), } diff --git a/src/push_service/keys.rs b/src/push_service/keys.rs index 50894d1f2..f4addb013 100644 --- a/src/push_service/keys.rs +++ b/src/push_service/keys.rs @@ -1,6 +1,7 @@ use std::collections::HashMap; use libsignal_protocol::{IdentityKey, PreKeyBundle, SenderCertificate}; +use reqwest::Method; use serde::Deserialize; use crate::{ @@ -13,8 +14,8 @@ use crate::{ }; use super::{ - HttpAuthOverride, PushService, SenderCertificateJson, ServiceError, - ServiceIdType, VerifyAccountResponse, + response::ReqwestExt, HttpAuthOverride, PushService, SenderCertificateJson, + ServiceError, ServiceIdType, VerifyAccountResponse, }; #[derive(Debug, Deserialize, Default)] @@ -29,13 +30,19 @@ impl PushService { &mut self, service_id_type: ServiceIdType, ) -> Result { - self.get_json( + self.request( + Method::GET, Endpoint::Service, - &format!("/v2/keys?identity={}", service_id_type), - &[], + format!("/v2/keys?identity={}", service_id_type), HttpAuthOverride::NoOverride, - ) + )? + .send() + .await? + .service_error_for_status() + .await? + .json() .await + .map_err(Into::into) } pub async fn register_pre_keys( @@ -43,19 +50,19 @@ impl PushService { service_id_type: ServiceIdType, pre_key_state: PreKeyState, ) -> Result<(), ServiceError> { - match self - .put_json( - Endpoint::Service, - &format!("/v2/keys?identity={}", service_id_type), - &[], - HttpAuthOverride::NoOverride, - pre_key_state, - ) - .await - { - Err(ServiceError::JsonDecodeError { .. }) => Ok(()), - r => r, - } + self.request( + Method::PUT, + Endpoint::Service, + format!("/v2/keys?identity={}", service_id_type), + HttpAuthOverride::NoOverride, + )? + .json(&pre_key_state) + .send() + .await? + .service_error_for_status() + .await?; + + Ok(()) } pub async fn get_pre_key( @@ -67,13 +74,19 @@ impl PushService { format!("/v2/keys/{}/{}?pq=true", destination.uuid, device_id); let mut pre_key_response: PreKeyResponse = self - .get_json( + .request( + Method::GET, Endpoint::Service, &path, - &[], HttpAuthOverride::NoOverride, - ) + )? + .send() + .await? + .service_error_for_status() + .await? + .json() .await?; + assert!(!pre_key_response.devices.is_empty()); let identity = IdentityKey::decode(&pre_key_response.identity_key)?; @@ -92,12 +105,17 @@ impl PushService { format!("/v2/keys/{}/{}?pq=true", destination.uuid, device_id) }; let pre_key_response: PreKeyResponse = self - .get_json( + .request( + Method::GET, Endpoint::Service, &path, - &[], HttpAuthOverride::NoOverride, - ) + )? + .send() + .await? + .service_error_for_status() + .await? + .json() .await?; let mut pre_keys = vec![]; let identity = IdentityKey::decode(&pre_key_response.identity_key)?; @@ -111,12 +129,17 @@ impl PushService { &mut self, ) -> Result { let cert: SenderCertificateJson = self - .get_json( + .request( + Method::GET, Endpoint::Service, "/v1/certificate/delivery", - &[], HttpAuthOverride::NoOverride, - ) + )? + .send() + .await? + .service_error_for_status() + .await? + .json() .await?; Ok(SenderCertificate::deserialize(&cert.certificate)?) } @@ -125,12 +148,17 @@ impl PushService { &mut self, ) -> Result { let cert: SenderCertificateJson = self - .get_json( + .request( + Method::GET, Endpoint::Service, "/v1/certificate/delivery?includeE164=false", - &[], HttpAuthOverride::NoOverride, - ) + )? + .send() + .await? + .service_error_for_status() + .await? + .json() .await?; Ok(SenderCertificate::deserialize(&cert.certificate)?) } @@ -160,23 +188,26 @@ impl PushService { pni_registration_ids: HashMap, signature_valid_on_each_signed_pre_key: bool, } - - let res: VerifyAccountResponse = self - .put_json( - Endpoint::Service, - "/v2/accounts/phone_number_identity_key_distribution", - &[], - HttpAuthOverride::NoOverride, - PniKeyDistributionRequest { - pni_identity_key: pni_identity_key.serialize().into(), - device_messages, - device_pni_signed_prekeys, - device_pni_last_resort_kyber_prekeys, - pni_registration_ids, - signature_valid_on_each_signed_pre_key, - }, - ) - .await?; - Ok(res) + self.request( + Method::PUT, + Endpoint::Service, + "/v2/accounts/phone_number_identity_key_distribution", + HttpAuthOverride::NoOverride, + )? + .json(&PniKeyDistributionRequest { + pni_identity_key: pni_identity_key.serialize().into(), + device_messages, + device_pni_signed_prekeys, + device_pni_last_resort_kyber_prekeys, + pni_registration_ids, + signature_valid_on_each_signed_pre_key, + }) + .send() + .await? + .service_error_for_status() + .await? + .json() + .await + .map_err(Into::into) } } diff --git a/src/push_service/linking.rs b/src/push_service/linking.rs index 5d9b9ded8..5d5026c6c 100644 --- a/src/push_service/linking.rs +++ b/src/push_service/linking.rs @@ -1,11 +1,12 @@ +use reqwest::Method; use serde::{Deserialize, Serialize}; use uuid::Uuid; use crate::configuration::Endpoint; use super::{ - DeviceActivationRequest, HttpAuth, HttpAuthOverride, PushService, - ServiceError, + response::ReqwestExt, DeviceActivationRequest, HttpAuth, HttpAuthOverride, + PushService, ServiceError, }; #[derive(Debug, Serialize)] @@ -59,18 +60,34 @@ impl PushService { link_request: &LinkRequest, http_auth: HttpAuth, ) -> Result { - self.put_json( + self.request( + Method::PUT, Endpoint::Service, "/v1/devices/link", - &[], HttpAuthOverride::Identified(http_auth), - link_request, - ) + )? + .json(&link_request) + .send() + .await? + .service_error_for_status() + .await? + .json() .await + .map_err(Into::into) } pub async fn unlink_device(&mut self, id: i64) -> Result<(), ServiceError> { - self.delete_json(Endpoint::Service, &format!("/v1/devices/{}", id), &[]) - .await + self.request( + Method::DELETE, + Endpoint::Service, + format!("/v1/devices/{}", id), + HttpAuthOverride::NoOverride, + )? + .send() + .await? + .service_error_for_status() + .await?; + + Ok(()) } } diff --git a/src/push_service/mod.rs b/src/push_service/mod.rs index f839f185b..c6cf31fd5 100644 --- a/src/push_service/mod.rs +++ b/src/push_service/mod.rs @@ -1,36 +1,23 @@ -use std::{io, time::Duration}; +use std::time::Duration; use crate::{ configuration::{Endpoint, ServiceCredentials}, pre_keys::{KyberPreKeyEntity, PreKeyEntity, SignedPreKeyEntity}, prelude::ServiceConfiguration, utils::serde_base64, - websocket::{tungstenite::TungsteniteWebSocket, SignalWebSocket}, + websocket::SignalWebSocket, }; -use bytes::{Buf, Bytes}; use derivative::Derivative; -use headers::{Authorization, HeaderMapExt}; -use http_body_util::{BodyExt, Full}; -use hyper::{ - body::Incoming, - header::{CONTENT_LENGTH, CONTENT_TYPE, USER_AGENT}, - Method, Request, Response, StatusCode, -}; -use hyper_rustls::HttpsConnector; -use hyper_timeout::TimeoutConnector; -use hyper_util::{ - client::legacy::{connect::HttpConnector, Client}, - rt::TokioExecutor, -}; use libsignal_protocol::{ error::SignalProtocolError, kem::{Key, Public}, IdentityKey, PreKeyBundle, PublicKey, }; -use prost::Message as ProtobufMessage; +use protobuf::ProtobufResponseExt; +use reqwest::{Method, RequestBuilder}; +use reqwest_websocket::RequestBuilderExt; use serde::{Deserialize, Serialize}; -use tokio_rustls::rustls; use tracing::{debug_span, Instrument}; pub const KEEPALIVE_TIMEOUT_SECONDS: Duration = Duration::from_secs(55); @@ -43,6 +30,7 @@ mod keys; mod linking; mod profile; mod registration; +mod response; mod stickers; pub use account::*; @@ -52,6 +40,7 @@ pub use keys::*; pub use linking::*; pub use profile::*; pub use registration::*; +pub(crate) use response::{ReqwestExt, SignalServiceResponse}; #[derive(Debug, Serialize, Deserialize)] pub struct ProofRequired { @@ -67,16 +56,6 @@ pub struct HttpAuth { pub password: String, } -/// This type is used in registration lock handling. -/// It's identical with HttpAuth, but used to avoid type confusion. -#[derive(Derivative, Clone, Serialize, Deserialize)] -#[derivative(Debug)] -pub struct AuthCredentials { - pub username: String, - #[derivative(Debug = "ignore")] - pub password: String, -} - #[derive(Debug, Clone)] pub enum HttpAuthOverride { NoOverride, @@ -164,514 +143,182 @@ pub struct StaleDevices { pub stale_devices: Vec, } -#[derive(Debug)] -struct RequestBody { - contents: Vec, - content_type: String, -} - #[derive(Clone)] pub struct PushService { cfg: ServiceConfiguration, - user_agent: String, credentials: Option, - client: - Client>, Full>, + client: reqwest::Client, } impl PushService { pub fn new( cfg: impl Into, credentials: Option, - user_agent: String, + user_agent: impl AsRef, ) -> Self { let cfg = cfg.into(); - let tls_config = Self::tls_config(&cfg); - - let https = hyper_rustls::HttpsConnectorBuilder::new() - .with_tls_config(tls_config) - .https_only() - .enable_http1() - .build(); - - // as in Signal-Android - let mut timeout_connector = TimeoutConnector::new(https); - timeout_connector.set_connect_timeout(Some(Duration::from_secs(10))); - timeout_connector.set_read_timeout(Some(Duration::from_secs(65))); - timeout_connector.set_write_timeout(Some(Duration::from_secs(65))); - - let client: Client<_, Full> = - Client::builder(TokioExecutor::new()).build(timeout_connector); + let client = reqwest::ClientBuilder::new() + .add_root_certificate( + reqwest::Certificate::from_pem( + cfg.certificate_authority.as_bytes(), + ) + .unwrap(), + ) + .connect_timeout(Duration::from_secs(10)) + .timeout(Duration::from_secs(65)) + .user_agent(user_agent.as_ref()) + .build() + .unwrap(); Self { cfg, credentials: credentials.and_then(|c| c.authorization()), client, - user_agent, } } - fn tls_config(cfg: &ServiceConfiguration) -> rustls::ClientConfig { - let mut cert_bytes = io::Cursor::new(&cfg.certificate_authority); - let roots = rustls_pemfile::certs(&mut cert_bytes); - - let mut root_certs = rustls::RootCertStore::empty(); - root_certs.add_parsable_certificates( - roots.map(|c| c.expect("parsable PEM files")), - ); - - rustls::ClientConfig::builder() - .with_root_certificates(root_certs) - .with_no_client_auth() - } - - #[tracing::instrument(skip(self, path, body), fields(path = %path.as_ref()))] - async fn request( + #[tracing::instrument(skip(self, path), fields(path = %path.as_ref()))] + pub fn request( &self, method: Method, endpoint: Endpoint, path: impl AsRef, - additional_headers: &[(&str, &str)], - credentials_override: HttpAuthOverride, - body: Option, - ) -> Result, ServiceError> { + auth_override: HttpAuthOverride, + ) -> Result { let url = self.cfg.base_url(endpoint).join(path.as_ref())?; - let mut builder = Request::builder() - .method(method) - .uri(url.as_str()) - .header(USER_AGENT, &self.user_agent); - - for (header, value) in additional_headers { - builder = builder.header(*header, *value); - } + let mut builder = self.client.request(method, url); - match credentials_override { + builder = match auth_override { HttpAuthOverride::NoOverride => { if let Some(HttpAuth { username, password }) = self.credentials.as_ref() { + builder.basic_auth(username, Some(password)) + } else { builder - .headers_mut() - .unwrap() - .typed_insert(Authorization::basic(username, password)); } }, HttpAuthOverride::Identified(HttpAuth { username, password }) => { - builder - .headers_mut() - .unwrap() - .typed_insert(Authorization::basic(&username, &password)); + builder.basic_auth(username, Some(password)) }, - HttpAuthOverride::Unidentified => (), + HttpAuthOverride::Unidentified => builder, }; - let request = if let Some(RequestBody { - contents, - content_type, - }) = body - { - builder - .header(CONTENT_LENGTH, contents.len() as u64) - .header(CONTENT_TYPE, content_type) - .body(Full::new(Bytes::from(contents))) - .unwrap() - } else { - builder.body(Full::default()).unwrap() - }; - - let mut response = self.client.request(request).await.map_err(|e| { - ServiceError::SendError { - reason: e.to_string(), - } - })?; - - match response.status() { - StatusCode::OK => Ok(response), - StatusCode::NO_CONTENT => Ok(response), - StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => { - Err(ServiceError::Unauthorized) - }, - StatusCode::NOT_FOUND => { - // This is 404 and means that e.g. recipient is not registered - Err(ServiceError::NotFoundError) - }, - StatusCode::PAYLOAD_TOO_LARGE => { - // This is 413 and means rate limit exceeded for Signal. - Err(ServiceError::RateLimitExceeded) - }, - StatusCode::CONFLICT => { - let mismatched_devices = - Self::json(&mut response).await.map_err(|e| { - tracing::error!( - "Failed to decode HTTP 409 response: {}", - e - ); - ServiceError::UnhandledResponseCode { - http_code: StatusCode::CONFLICT.as_u16(), - } - })?; - Err(ServiceError::MismatchedDevicesException( - mismatched_devices, - )) - }, - StatusCode::GONE => { - let stale_devices = - Self::json(&mut response).await.map_err(|e| { - tracing::error!( - "Failed to decode HTTP 410 response: {}", - e - ); - ServiceError::UnhandledResponseCode { - http_code: StatusCode::GONE.as_u16(), - } - })?; - Err(ServiceError::StaleDevices(stale_devices)) - }, - StatusCode::LOCKED => { - let locked = Self::json(&mut response).await.map_err(|e| { - tracing::error!( - ?response, - "Failed to decode HTTP 423 response: {}", - e - ); - ServiceError::UnhandledResponseCode { - http_code: StatusCode::LOCKED.as_u16(), - } - })?; - Err(ServiceError::Locked(locked)) - }, - StatusCode::PRECONDITION_REQUIRED => { - let proof_required = - Self::json(&mut response).await.map_err(|e| { - tracing::error!( - "Failed to decode HTTP 428 response: {}", - e - ); - ServiceError::UnhandledResponseCode { - http_code: StatusCode::PRECONDITION_REQUIRED - .as_u16(), - } - })?; - Err(ServiceError::ProofRequiredError(proof_required)) - }, - // XXX: fill in rest from PushServiceSocket - code => { - tracing::trace!( - "Unhandled response {} with body: {}", - code.as_u16(), - Self::text(&mut response).await?, - ); - Err(ServiceError::UnhandledResponseCode { - http_code: code.as_u16(), - }) - }, - } - } - - async fn body( - response: &mut Response, - ) -> Result { - Ok(response - .collect() - .await - .map_err(|e| ServiceError::ResponseError { - reason: format!("failed to aggregate HTTP response body: {e}"), - })? - .aggregate()) - } - - #[tracing::instrument(skip(response), fields(status = %response.status()))] - async fn json( - response: &mut Response, - ) -> Result - where - for<'de> T: Deserialize<'de>, - { - let body = Self::body(response).await?; - - if body.has_remaining() { - serde_json::from_reader(body.reader()) - } else { - serde_json::from_value(serde_json::Value::Null) - } - .map_err(|e| ServiceError::JsonDecodeError { - reason: e.to_string(), - }) - } - - #[tracing::instrument(skip(response), fields(status = %response.status()))] - async fn protobuf( - response: &mut Response, - ) -> Result - where - M: ProtobufMessage + Default, - { - let body = Self::body(response).await?; - M::decode(body).map_err(ServiceError::ProtobufDecodeError) - } - - #[tracing::instrument(skip(response), fields(status = %response.status()))] - async fn text( - response: &mut Response, - ) -> Result { - let body = Self::body(response).await?; - io::read_to_string(body.reader()).map_err(|e| { - ServiceError::ResponseError { - reason: format!("failed to read HTTP response body: {e}"), - } - }) + Ok(builder) } -} -impl PushService { - #[tracing::instrument(skip(self))] - pub(crate) async fn get_json( + pub async fn ws( &mut self, - service: Endpoint, path: &str, - additional_headers: &[(&str, &str)], - credentials_override: HttpAuthOverride, - ) -> Result - where - for<'de> T: Deserialize<'de>, - { - let mut response = self - .request( - Method::GET, - service, - path, - additional_headers, - credentials_override, - None, - ) - .await?; + keepalive_path: &str, + additional_headers: &[(&'static str, &str)], + credentials: Option, + ) -> Result { + let span = debug_span!("websocket"); - Self::json(&mut response).await - } + let endpoint = self.cfg.base_url(Endpoint::Service); + let mut url = endpoint.join(path).expect("valid url"); + url.set_scheme("wss").expect("valid https base url"); - #[tracing::instrument(skip(self))] - async fn delete_json( - &mut self, - service: Endpoint, - path: &str, - additional_headers: &[(&str, &str)], - ) -> Result - where - for<'de> T: Deserialize<'de>, - { - let mut response = self - .request( - Method::DELETE, - service, - path, - additional_headers, - HttpAuthOverride::NoOverride, - None, - ) - .await?; + if let Some(credentials) = credentials { + url.query_pairs_mut() + .append_pair("login", &credentials.login()) + .append_pair( + "password", + credentials.password.as_ref().expect("a password"), + ); + } - Self::json(&mut response).await - } + let mut builder = self.client.get(url); + for (key, value) in additional_headers { + builder = builder.header(*key, *value); + } - #[tracing::instrument(skip(self, value))] - pub async fn put_json( - &mut self, - service: Endpoint, - path: &str, - additional_headers: &[(&str, &str)], - credentials_override: HttpAuthOverride, - value: S, - ) -> Result - where - for<'de> D: Deserialize<'de>, - S: Send + Serialize, - { - let json = serde_json::to_vec(&value).map_err(|e| { - ServiceError::JsonDecodeError { - reason: e.to_string(), - } - })?; - - let mut response = self - .request( - Method::PUT, - service, - path, - additional_headers, - credentials_override, - Some(RequestBody { - contents: json, - content_type: "application/json".into(), - }), - ) + let ws = builder + .upgrade() + .send() + .await? + .into_websocket() + .instrument(span.clone()) .await?; - Self::json(&mut response).await + let (ws, task) = + SignalWebSocket::from_socket(ws, keepalive_path.to_owned()); + let task = task.instrument(span); + tokio::task::spawn(task); + Ok(ws) } - #[tracing::instrument(skip(self, value))] - async fn patch_json( + pub(crate) async fn get_group( &mut self, - service: Endpoint, - path: &str, - additional_headers: &[(&str, &str)], - credentials_override: HttpAuthOverride, - value: S, - ) -> Result - where - for<'de> D: Deserialize<'de>, - S: Send + Serialize, - { - let json = serde_json::to_vec(&value).map_err(|e| { - ServiceError::JsonDecodeError { - reason: e.to_string(), - } - })?; - - let mut response = self - .request( - Method::PATCH, - service, - path, - additional_headers, - credentials_override, - Some(RequestBody { - contents: json, - content_type: "application/json".into(), - }), - ) - .await?; - - Self::json(&mut response).await + credentials: HttpAuth, + ) -> Result { + self.request( + Method::GET, + Endpoint::Storage, + "/v1/groups/", + HttpAuthOverride::Identified(credentials), + )? + .send() + .await? + .service_error_for_status() + .await? + .protobuf() + .await + .map_err(Into::into) } +} - #[tracing::instrument(skip(self, value))] - async fn post_json( - &mut self, - service: Endpoint, - path: &str, - additional_headers: &[(&str, &str)], - credentials_override: HttpAuthOverride, - value: S, - ) -> Result - where - for<'de> D: Deserialize<'de>, - S: Send + Serialize, - { - let json = serde_json::to_vec(&value).map_err(|e| { - ServiceError::JsonDecodeError { - reason: e.to_string(), - } - })?; - - let mut response = self - .request( - Method::POST, - service, - path, - additional_headers, - credentials_override, - Some(RequestBody { - contents: json, - content_type: "application/json".into(), - }), - ) - .await?; +pub(crate) mod protobuf { + use async_trait::async_trait; + use prost::{EncodeError, Message}; + use reqwest::{header, RequestBuilder, Response}; - Self::json(&mut response).await - } + use super::ServiceError; - #[tracing::instrument(skip(self))] - async fn get_protobuf( - &mut self, - service: Endpoint, - path: &str, - additional_headers: &[(&str, &str)], - credentials_override: HttpAuthOverride, - ) -> Result + pub(crate) trait ProtobufRequestBuilderExt where - T: Default + ProtobufMessage, + Self: Sized, { - let mut response = self - .request( - Method::GET, - service, - path, - additional_headers, - credentials_override, - None, - ) - .await?; - - Self::protobuf(&mut response).await + /// Set the request payload encoded as protobuf. + /// Sets the `Content-Type` header to `application/protobuf` + #[allow(dead_code)] + fn protobuf( + self, + value: T, + ) -> Result; } - #[tracing::instrument(skip(self, value))] - async fn put_protobuf( - &mut self, - service: Endpoint, - path: &str, - additional_headers: &[(&str, &str)], - value: S, - ) -> Result - where - D: Default + ProtobufMessage, - S: Sized + ProtobufMessage, - { - let protobuf = value.encode_to_vec(); - - let mut response = self - .request( - Method::PUT, - service, - path, - additional_headers, - HttpAuthOverride::NoOverride, - Some(RequestBody { - contents: protobuf, - content_type: "application/x-protobuf".into(), - }), - ) - .await?; - - Self::protobuf(&mut response).await + #[async_trait::async_trait] + pub(crate) trait ProtobufResponseExt { + /// Get the response body decoded from Protobuf + async fn protobuf( + self, + ) -> Result; } - pub async fn ws( - &mut self, - path: &str, - keepalive_path: &str, - additional_headers: &[(&str, &str)], - credentials: Option, - ) -> Result { - let span = debug_span!("websocket"); - let (ws, stream) = TungsteniteWebSocket::with_tls_config( - Self::tls_config(&self.cfg), - self.cfg.base_url(Endpoint::Service), - path, - additional_headers, - credentials.as_ref(), - ) - .instrument(span.clone()) - .await?; - let (ws, task) = - SignalWebSocket::from_socket(ws, stream, keepalive_path.to_owned()); - let task = task.instrument(span); - tokio::task::spawn(task); - Ok(ws) + impl ProtobufRequestBuilderExt for RequestBuilder { + fn protobuf( + self, + value: T, + ) -> Result { + let mut buf = Vec::new(); + value.encode(&mut buf)?; + let this = + self.header(header::CONTENT_TYPE, "application/protobuf"); + Ok(this.body(buf)) + } } - pub(crate) async fn get_group( - &mut self, - credentials: HttpAuth, - ) -> Result { - self.get_protobuf( - Endpoint::Storage, - "/v1/groups/", - &[], - HttpAuthOverride::Identified(credentials), - ) - .await + #[async_trait] + impl ProtobufResponseExt for Response { + async fn protobuf( + self, + ) -> Result { + let body = self.bytes().await?; + let decoded = T::decode(body)?; + Ok(decoded) + } } } @@ -685,11 +332,8 @@ mod tests { let configs = &[SignalServers::Staging, SignalServers::Production]; for cfg in configs { - let _ = super::PushService::new( - cfg, - None, - "libsignal-service test".to_string(), - ); + let _ = + super::PushService::new(cfg, None, "libsignal-service test"); } } diff --git a/src/push_service/profile.rs b/src/push_service/profile.rs index 0a444fef6..a1a859ac9 100644 --- a/src/push_service/profile.rs +++ b/src/push_service/profile.rs @@ -1,3 +1,4 @@ +use reqwest::Method; use serde::{Deserialize, Serialize}; use zkgroup::profiles::{ProfileKeyCommitment, ProfileKeyVersion}; @@ -5,12 +6,12 @@ use crate::{ configuration::Endpoint, content::ServiceError, profile_cipher::ProfileCipherError, - push_service::{AvatarWrite, HttpAuthOverride}, + push_service::AvatarWrite, utils::{serde_base64, serde_optional_base64}, Profile, ServiceAddress, }; -use super::{DeviceCapabilities, PushService}; +use super::{DeviceCapabilities, HttpAuthOverride, PushService, ReqwestExt}; #[derive(Debug, Deserialize)] #[serde(rename_all = "camelCase")] @@ -102,13 +103,19 @@ impl PushService { format!("/v1/profile/{}", address.uuid) }; // TODO: set locale to en_US - self.get_json( + self.request( + Method::GET, Endpoint::Service, &endpoint, - &[], HttpAuthOverride::NoOverride, - ) + )? + .send() + .await? + .service_error_for_status() + .await? + .json() .await + .map_err(Into::into) } pub async fn retrieve_profile_avatar( @@ -161,35 +168,34 @@ impl PushService { }; // XXX this should be a struct; cfr ProfileAvatarUploadAttributes - let response: Result = self - .put_json( + let upload_url: Result = self + .request( + Method::PUT, Endpoint::Service, "/v1/profile", - &[], HttpAuthOverride::NoOverride, - command, - ) + )? + .json(&command) + .send() + .await? + .service_error_for_status() + .await? + .json() .await; - match (response, avatar) { - (Ok(_url), AvatarWrite::NewAvatar(_avatar)) => { + + match (upload_url, avatar) { + (_url, AvatarWrite::NewAvatar(_avatar)) => { // FIXME unreachable!("Uploading avatar unimplemented"); }, // FIXME cleanup when #54883 is stable and MSRV: // or-patterns syntax is experimental // see issue #54883 for more information - ( - Err(ServiceError::JsonDecodeError { .. }), - AvatarWrite::RetainAvatar, - ) - | ( - Err(ServiceError::JsonDecodeError { .. }), - AvatarWrite::NoAvatar, - ) => { + (Err(_), AvatarWrite::RetainAvatar) + | (Err(_), AvatarWrite::NoAvatar) => { // OWS sends an empty string when there's no attachment Ok(None) }, - (Err(e), _) => Err(e), (Ok(_resp), AvatarWrite::RetainAvatar) | (Ok(_resp), AvatarWrite::NoAvatar) => { tracing::warn!( diff --git a/src/push_service/registration.rs b/src/push_service/registration.rs index 6d353b962..2ab6ea0c2 100644 --- a/src/push_service/registration.rs +++ b/src/push_service/registration.rs @@ -1,15 +1,27 @@ +use derivative::Derivative; use libsignal_protocol::IdentityKey; +use reqwest::Method; use serde::{Deserialize, Serialize}; use uuid::Uuid; -use super::{AccountAttributes, AuthCredentials, PushService, ServiceError}; +use super::{AccountAttributes, PushService, ServiceError}; use crate::{ configuration::Endpoint, pre_keys::{KyberPreKeyEntity, SignedPreKeyEntity}, - push_service::HttpAuthOverride, + push_service::{response::ReqwestExt, HttpAuthOverride}, utils::serde_base64, }; +/// This type is used in registration lock handling. +/// It's identical with HttpAuth, but used to avoid type confusion. +#[derive(Derivative, Clone, Serialize, Deserialize)] +#[derivative(Debug)] +pub struct AuthCredentials { + pub username: String, + #[derivative(Debug = "ignore")] + pub password: String, +} + #[derive(Debug, Deserialize)] #[serde(rename_all = "camelCase")] pub struct RegistrationLockFailure { @@ -31,21 +43,13 @@ pub struct VerifyAccountResponse { pub number: Option, } -#[derive(Clone, Copy, Debug, Eq, PartialEq)] +#[derive(Clone, Copy, Debug, Eq, PartialEq, Serialize)] +#[serde(rename_all = "snake_case")] pub enum VerificationTransport { Sms, Voice, } -impl VerificationTransport { - pub fn as_str(&self) -> &str { - match self { - Self::Sms => "sms", - Self::Voice => "voice", - } - } -} - #[derive(Clone, Debug)] pub enum RegistrationMethod<'a> { SessionId(&'a str), @@ -148,7 +152,13 @@ impl PushService { device_activation_request: DeviceActivationRequest, } - let req = RegistrationSessionRequestBody { + self.request( + Method::POST, + Endpoint::Service, + "/v1/registration", + HttpAuthOverride::NoOverride, + )? + .json(&RegistrationSessionRequestBody { session_id: registration_method.session_id(), recovery_password: registration_method.recovery_password(), account_attributes, @@ -157,18 +167,14 @@ impl PushService { pni_identity_key: pni_identity_key.serialize().into(), device_activation_request, every_signed_key_valid: true, - }; - - let res: VerifyAccountResponse = self - .post_json( - Endpoint::Service, - "/v1/registration", - &[], - HttpAuthOverride::NoOverride, - req, - ) - .await?; - Ok(res) + }) + .send() + .await? + .service_error_for_status() + .await? + .json() + .await + .map_err(Into::into) } // Equivalent of Java's @@ -190,24 +196,26 @@ impl PushService { push_token_type: Option<&'a str>, } - let req = VerificationSessionMetadataRequestBody { + self.request( + Method::POST, + Endpoint::Service, + "/v1/verification/session", + HttpAuthOverride::Unidentified, + )? + .json(&VerificationSessionMetadataRequestBody { number, push_token_type: push_token.as_ref().map(|_| "fcm"), push_token, mcc, mnc, - }; - - let res: RegistrationSessionMetadataResponse = self - .post_json( - Endpoint::Service, - "/v1/verification/session", - &[], - HttpAuthOverride::Unidentified, - req, - ) - .await?; - Ok(res) + }) + .send() + .await? + .service_error_for_status() + .await? + .json() + .await + .map_err(Into::into) } pub async fn patch_verification_session<'a>( @@ -230,25 +238,27 @@ impl PushService { push_token_type: Option<&'a str>, } - let req = UpdateVerificationSessionRequestBody { + self.request( + Method::PATCH, + Endpoint::Service, + format!("/v1/verification/session/{}", session_id), + HttpAuthOverride::Unidentified, + )? + .json(&UpdateVerificationSessionRequestBody { captcha, push_token_type: push_token.as_ref().map(|_| "fcm"), push_token, mcc, mnc, push_challenge, - }; - - let res: RegistrationSessionMetadataResponse = self - .patch_json( - Endpoint::Service, - &format!("/v1/verification/session/{}", session_id), - &[], - HttpAuthOverride::Unidentified, - req, - ) - .await?; - Ok(res) + }) + .send() + .await? + .service_error_for_status() + .await? + .json() + .await + .map_err(Into::into) } // Equivalent of Java's @@ -271,20 +281,26 @@ impl PushService { // locale: Option, transport: VerificationTransport, ) -> Result { - let mut req = std::collections::HashMap::new(); - req.insert("transport", transport.as_str()); - req.insert("client", client); + #[derive(Debug, Serialize)] + struct VerificationCodeRequest<'a> { + transport: VerificationTransport, + client: &'a str, + } - let res: RegistrationSessionMetadataResponse = self - .post_json( - Endpoint::Service, - &format!("/v1/verification/session/{}/code", session_id), - &[], - HttpAuthOverride::Unidentified, - req, - ) - .await?; - Ok(res) + self.request( + Method::POST, + Endpoint::Service, + format!("/v1/verification/session/{}/code", session_id), + HttpAuthOverride::Unidentified, + )? + .json(&VerificationCodeRequest { transport, client }) + .send() + .await? + .service_error_for_status() + .await? + .json() + .await + .map_err(Into::into) } pub async fn submit_verification_code( @@ -292,18 +308,26 @@ impl PushService { session_id: &str, verification_code: &str, ) -> Result { - let mut req = std::collections::HashMap::new(); - req.insert("code", verification_code); + #[derive(Debug, Serialize)] + struct VerificationCode<'a> { + code: &'a str, + } - let res: RegistrationSessionMetadataResponse = self - .put_json( - Endpoint::Service, - &format!("/v1/verification/session/{}/code", session_id), - &[], - HttpAuthOverride::Unidentified, - req, - ) - .await?; - Ok(res) + self.request( + Method::PUT, + Endpoint::Service, + format!("/v1/verification/session/{}/code", session_id), + HttpAuthOverride::Unidentified, + )? + .json(&VerificationCode { + code: verification_code, + }) + .send() + .await? + .service_error_for_status() + .await? + .json() + .await + .map_err(Into::into) } } diff --git a/src/push_service/response.rs b/src/push_service/response.rs new file mode 100644 index 000000000..9fc02c6e9 --- /dev/null +++ b/src/push_service/response.rs @@ -0,0 +1,160 @@ +use reqwest::StatusCode; + +use crate::proto::WebSocketResponseMessage; + +use super::ServiceError; + +async fn service_error_for_status(response: R) -> Result +where + R: SignalServiceResponse, + ServiceError: From<::Error>, +{ + match response.status_code() { + StatusCode::OK => Ok(response), + StatusCode::NO_CONTENT => Ok(response), + StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => { + Err(ServiceError::Unauthorized) + }, + StatusCode::NOT_FOUND => { + // This is 404 and means that e.g. recipient is not registered + Err(ServiceError::NotFoundError) + }, + StatusCode::PAYLOAD_TOO_LARGE => { + // This is 413 and means rate limit exceeded for Signal. + Err(ServiceError::RateLimitExceeded) + }, + StatusCode::CONFLICT => { + let mismatched_devices = + response.json().await.map_err(|error| { + tracing::error!( + %error, + "failed to decode HTTP 409 status" + ); + ServiceError::UnhandledResponseCode { + http_code: StatusCode::CONFLICT.as_u16(), + } + })?; + Err(ServiceError::MismatchedDevicesException(mismatched_devices)) + }, + StatusCode::GONE => { + let stale_devices = response.json().await.map_err(|error| { + tracing::error!(%error, "failed to decode HTTP 410 status"); + ServiceError::UnhandledResponseCode { + http_code: StatusCode::GONE.as_u16(), + } + })?; + Err(ServiceError::StaleDevices(stale_devices)) + }, + StatusCode::LOCKED => { + let locked = response.json().await.map_err(|error| { + tracing::error!(%error, "failed to decode HTTP 423 status"); + ServiceError::UnhandledResponseCode { + http_code: StatusCode::LOCKED.as_u16(), + } + })?; + Err(ServiceError::Locked(locked)) + }, + StatusCode::PRECONDITION_REQUIRED => { + let proof_required = response.json().await.map_err(|error| { + tracing::error!( + %error, + "failed to decode HTTP 428 status" + ); + ServiceError::UnhandledResponseCode { + http_code: StatusCode::PRECONDITION_REQUIRED.as_u16(), + } + })?; + Err(ServiceError::ProofRequiredError(proof_required)) + }, + // XXX: fill in rest from PushServiceSocket + code => { + let response_text = response.text().await?; + tracing::trace!(status_code =% code, body = response_text, "unhandled HTTP response"); + Err(ServiceError::UnhandledResponseCode { + http_code: code.as_u16(), + }) + }, + } +} + +#[async_trait::async_trait] +pub(crate) trait SignalServiceResponse { + type Error: std::error::Error; + + fn status_code(&self) -> StatusCode; + + async fn json(self) -> Result + where + for<'de> U: serde::Deserialize<'de>; + + async fn text(self) -> Result; +} + +#[async_trait::async_trait] +impl SignalServiceResponse for reqwest::Response { + type Error = reqwest::Error; + + fn status_code(&self) -> StatusCode { + self.status() + } + + async fn json(self) -> Result + where + for<'de> U: serde::Deserialize<'de>, + { + reqwest::Response::json(self).await + } + + async fn text(self) -> Result { + reqwest::Response::text(self).await + } +} + +#[async_trait::async_trait] +impl SignalServiceResponse for WebSocketResponseMessage { + type Error = ServiceError; + + fn status_code(&self) -> StatusCode { + StatusCode::from_u16(self.status() as u16).unwrap_or_default() + } + + async fn json(self) -> Result + where + for<'de> U: serde::Deserialize<'de>, + { + serde_json::from_slice(self.body()).map_err(Into::into) + } + + async fn text(self) -> Result { + Ok(self + .body + .map(|body| String::from_utf8_lossy(&body).to_string()) + .unwrap_or_default()) + } +} + +#[async_trait::async_trait] +pub(crate) trait ReqwestExt +where + Self: Sized, +{ + /// convenience error handler to be used in the builder-style API of `reqwest::Response` + async fn service_error_for_status( + self, + ) -> Result; +} + +#[async_trait::async_trait] +impl ReqwestExt for reqwest::Response { + async fn service_error_for_status( + self, + ) -> Result { + service_error_for_status(self).await + } +} + +impl WebSocketResponseMessage { + pub async fn service_error_for_status(self) -> Result { + service_error_for_status(self).await + } +} diff --git a/src/receiver.rs b/src/receiver.rs index ba69320f0..57016dce7 100644 --- a/src/receiver.rs +++ b/src/receiver.rs @@ -64,7 +64,7 @@ impl MessageReceiver { retries += 1; if retries >= MAX_DOWNLOAD_RETRIES { return Err(ServiceError::Timeout { - reason: "too many retries".into(), + reason: "too many retries", }); } }, diff --git a/src/sender.rs b/src/sender.rs index ae8a862e1..691e03939 100644 --- a/src/sender.rs +++ b/src/sender.rs @@ -68,7 +68,7 @@ pub struct SentMessage { /// Attachment specification to be used for uploading. /// /// Loose equivalent of Java's `SignalServiceAttachmentStream`. -#[derive(Debug)] +#[derive(Debug, Default)] pub struct AttachmentSpec { pub content_type: String, pub length: usize, @@ -82,7 +82,6 @@ pub struct AttachmentSpec { pub blur_hash: Option, } -/// Equivalent of Java's `SignalServiceMessageSender`. #[derive(Clone)] pub struct MessageSender { identified_ws: SignalWebSocket, @@ -109,13 +108,18 @@ pub enum AttachmentUploadError { #[derive(thiserror::Error, Debug)] pub enum MessageSenderError { - #[error("{0}")] + #[error("service error: {0}")] ServiceError(#[from] ServiceError), + #[error("protocol error: {0}")] ProtocolError(#[from] SignalProtocolError), + #[error("Failed to upload attachment {0}")] AttachmentUploadError(#[from] AttachmentUploadError), + #[error("primary device can't send sync message {0:?}")] + SendSyncMessageError(sync_message::request::Type), + #[error("Untrusted identity key with {address:?}")] UntrustedIdentity { address: ServiceAddress }, @@ -222,9 +226,10 @@ where .get_attachment_v2_upload_attributes() .instrument(tracing::trace_span!("requesting upload attributes")) .await?; + let (id, digest) = self .service - .upload_attachment(&attrs, &mut std::io::Cursor::new(&contents)) + .upload_attachment(attrs, &mut std::io::Cursor::new(&contents)) .instrument(tracing::trace_span!("Uploading attachment")) .await?; @@ -778,13 +783,7 @@ where request_type: sync_message::request::Type, ) -> Result<(), MessageSenderError> { if self.device_id == DEFAULT_DEVICE_ID.into() { - let reason = format!( - "Primary device can't send sync requests, ignoring {:?}", - request_type - ); - return Err(MessageSenderError::ServiceError( - ServiceError::SendError { reason }, - )); + return Err(MessageSenderError::SendSyncMessageError(request_type)); } let msg = SyncMessage { diff --git a/src/websocket/mod.rs b/src/websocket/mod.rs index 4736edc45..c279b1f28 100644 --- a/src/websocket/mod.rs +++ b/src/websocket/mod.rs @@ -9,18 +9,20 @@ use futures::channel::{mpsc, oneshot}; use futures::future::BoxFuture; use futures::prelude::*; use futures::stream::FuturesUnordered; -use prost::Message; -use serde::{Deserialize, Serialize}; +use reqwest::Method; +use reqwest_websocket::WebSocket; +use tokio::time::Instant; -use crate::messagepipe::{WebSocketService, WebSocketStreamItem}; use crate::proto::{ web_socket_message, WebSocketMessage, WebSocketRequestMessage, WebSocketResponseMessage, }; -use crate::push_service::{MismatchedDevices, ServiceError}; +use crate::push_service::{self, ServiceError, SignalServiceResponse}; +mod request; mod sender; -pub(crate) mod tungstenite; + +pub use request::WebSocketRequestMessageBuilder; type RequestStreamItem = ( WebSocketRequestMessage, @@ -61,7 +63,7 @@ struct SignalWebSocketInner { stream: Option, } -struct SignalWebSocketProcess { +struct SignalWebSocketProcess { /// Whether to enable keep-alive or not (and send a request to this path) keep_alive_path: String, @@ -73,7 +75,7 @@ struct SignalWebSocketProcess { /// Signal's requests should go in here, to be delivered to the application. request_sink: mpsc::Sender, - outgoing_request_map: HashMap< + outgoing_requests: HashMap< u64, oneshot::Sender>, >, @@ -84,36 +86,35 @@ struct SignalWebSocketProcess { BoxFuture<'static, Result>, >, - // WS backend stuff - ws: WS, - stream: WS::Stream, + ws: WebSocket, } -impl SignalWebSocketProcess { +impl SignalWebSocketProcess { async fn process_frame( &mut self, - frame: Bytes, + frame: Vec, ) -> Result<(), ServiceError> { - let msg = WebSocketMessage::decode(frame)?; + use prost::Message; + let msg = WebSocketMessage::decode(Bytes::from(frame))?; if let Some(request) = &msg.request { tracing::trace!( - "decoded WebSocketMessage request {{ r#type: {:?}, verb: {:?}, path: {:?}, body: {} bytes, headers: {:?}, id: {:?} }}", - msg.r#type(), + msg_type =? msg.r#type(), + request.id, request.verb, request.path, - request.body.as_ref().map(|x| x.len()).unwrap_or(0), - request.headers, - request.id, + request_body_size_bytes = request.body.as_ref().map(|x| x.len()).unwrap_or(0), + ?request.headers, + "decoded WebSocketMessage request" ); } else if let Some(response) = &msg.response { tracing::trace!( - "decoded WebSocketMessage response {{ r#type: {:?}, status: {:?}, message: {:?}, body: {} bytes, headers: {:?}, id: {:?} }}", - msg.r#type(), + msg_type =? msg.r#type(), response.status, response.message, - response.body.as_ref().map(|x| x.len()).unwrap_or(0), - response.headers, + response_body_size_bytes = response.body.as_ref().map(|x| x.len()).unwrap_or(0), + ?response.headers, response.id, + "decoded WebSocketMessage response" ); } else { tracing::debug!("decoded {msg:?}"); @@ -121,29 +122,27 @@ impl SignalWebSocketProcess { use web_socket_message::Type; match (msg.r#type(), msg.request, msg.response) { - (Type::Unknown, _, _) => Err(ServiceError::InvalidFrameError { - reason: "Unknown frame type".into(), + (Type::Unknown, _, _) => Err(ServiceError::InvalidFrame { + reason: "unknown frame type", }), (Type::Request, Some(request), _) => { let (sink, recv) = oneshot::channel(); tracing::trace!("sending request with body"); self.request_sink.send((request, sink)).await.map_err( - |_| ServiceError::WsError { - reason: "request handler failed".into(), + |_| ServiceError::WsClosing { + reason: "request handler failed", }, )?; self.outgoing_responses.push(Box::pin(recv)); Ok(()) }, - (Type::Request, None, _) => Err(ServiceError::InvalidFrameError { - reason: "Type was request, but does not contain request." - .into(), + (Type::Request, None, _) => Err(ServiceError::InvalidFrame { + reason: "type was request, but does not contain request", }), (Type::Response, _, Some(response)) => { if let Some(id) = response.id { - if let Some(responder) = - self.outgoing_request_map.remove(&id) + if let Some(responder) = self.outgoing_requests.remove(&id) { if let Err(e) = responder.send(Ok(response)) { tracing::warn!( @@ -155,10 +154,11 @@ impl SignalWebSocketProcess { } else if let Some(_x) = self.outgoing_keep_alive_set.take(&id) { - if response.status() != 200 { + let status_code = response.status(); + if status_code != 200 { tracing::warn!( - "Response code for keep-alive is not 200: {:?}", - response + status_code, + "response code for keep-alive != 200" ); return Err(ServiceError::UnhandledResponseCode { http_code: response.status() as u16, @@ -166,17 +166,16 @@ impl SignalWebSocketProcess { } } else { tracing::warn!( - "Response for non existing request: {:?}", - response + ?response, + "response for non existing request" ); } } Ok(()) }, - (Type::Response, _, None) => Err(ServiceError::InvalidFrameError { - reason: "Type was response, but does not contain response." - .into(), + (Type::Response, _, None) => Err(ServiceError::InvalidFrame { + reason: "type was response, but does not contain response", }), } } @@ -186,83 +185,107 @@ impl SignalWebSocketProcess { let mut rng = rand::thread_rng(); loop { let id = rng.gen(); - if !self.outgoing_request_map.contains_key(&id) { + if !self.outgoing_requests.contains_key(&id) { return id; } } } async fn run(mut self) -> Result<(), ServiceError> { + let mut ka_interval = tokio::time::interval_at( + Instant::now(), + push_service::KEEPALIVE_TIMEOUT_SECONDS, + ); + loop { futures::select! { + _ = ka_interval.tick().fuse() => { + use prost::Message; + tracing::debug!("sending keep-alive"); + let request = WebSocketRequestMessage::new(Method::GET) + .id(self.next_request_id()) + .path(&self.keep_alive_path) + .build(); + self.outgoing_keep_alive_set.insert(request.id.unwrap()); + let msg = WebSocketMessage { + r#type: Some(web_socket_message::Type::Request.into()), + request: Some(request), + ..Default::default() + }; + let buffer = msg.encode_to_vec(); + if let Err(e) = self.ws.send(reqwest_websocket::Message::Binary(buffer)).await { + tracing::info!("Websocket sink has closed: {:?}.", e); + break; + }; + }, // Process requests from the application, forward them to Signal x = self.requests.next() => { match x { Some((mut request, responder)) => { + use prost::Message; + // Regenerate ID if already in the table request.id = Some( request .id - .filter(|x| !self.outgoing_request_map.contains_key(x)) + .filter(|x| !self.outgoing_requests.contains_key(x)) .unwrap_or_else(|| self.next_request_id()), ); tracing::trace!( - "sending WebSocketRequestMessage {{ verb: {:?}, path: {:?}, body (bytes): {:?}, headers: {:?}, id: {:?} }}", + request.id, request.verb, request.path, - request.body.as_ref().map(|x| x.len()), - request.headers, - request.id, + request_body_size_bytes = request.body.as_ref().map(|x| x.len()), + ?request.headers, + "sending WebSocketRequestMessage", ); - self.outgoing_request_map.insert(request.id.unwrap(), responder); + self.outgoing_requests.insert(request.id.unwrap(), responder); let msg = WebSocketMessage { r#type: Some(web_socket_message::Type::Request.into()), request: Some(request), ..Default::default() }; let buffer = msg.encode_to_vec(); - self.ws.send_message(buffer.into()).await? + self.ws.send(reqwest_websocket::Message::Binary(buffer)).await? } None => { - return Err(ServiceError::WsError { - reason: "SignalWebSocket: end of application request stream; socket closing".into() + return Err(ServiceError::WsClosing { + reason: "end of application request stream; socket closing" }); } } } - web_socket_item = self.stream.next() => { + // Incoming websocket message + web_socket_item = self.ws.next().fuse() => { + use reqwest_websocket::Message; match web_socket_item { - Some(WebSocketStreamItem::Message(frame)) => { + Some(Ok(Message::Close { code, reason })) => { + tracing::warn!(%code, reason, "websocket closed"); + break; + }, + Some(Ok(Message::Binary(frame))) => { self.process_frame(frame).await?; } - Some(WebSocketStreamItem::KeepAliveRequest) => { - // XXX: would be nicer if we could drop this request into the request - // queue above. - tracing::debug!("Sending keep alive upon request"); - let request = WebSocketRequestMessage { - id: Some(self.next_request_id()), - path: Some(self.keep_alive_path.clone()), - verb: Some("GET".into()), - ..Default::default() - }; - self.outgoing_keep_alive_set.insert(request.id.unwrap()); - let msg = WebSocketMessage { - r#type: Some(web_socket_message::Type::Request.into()), - request: Some(request), - ..Default::default() - }; - let buffer = msg.encode_to_vec(); - self.ws.send_message(buffer.into()).await?; + Some(Ok(Message::Ping(_))) => { + tracing::trace!("received ping"); } + Some(Ok(Message::Pong(_))) => { + tracing::trace!("received pong"); + } + Some(Ok(Message::Text(_))) => { + tracing::trace!("received text (unsupported, skipping)"); + } + Some(Err(e)) => return Err(ServiceError::WsError(e)), None => { - return Err(ServiceError::WsError { - reason: "end of web request stream; socket closing".into() + return Err(ServiceError::WsClosing { + reason: "end of web request stream; socket closing" }); } } } response = self.outgoing_responses.next() => { + use prost::Message; match response { Some(Ok(response)) => { tracing::trace!("sending response {:?}", response); @@ -273,10 +296,10 @@ impl SignalWebSocketProcess { ..Default::default() }; let buffer = msg.encode_to_vec(); - self.ws.send_message(buffer.into()).await?; + self.ws.send(buffer.into()).await?; } - Some(Err(e)) => { - tracing::error!("could not generate response to a Signal request; responder was canceled: {}. Continuing.", e); + Some(Err(error)) => { + tracing::error!(%error, "could not generate response to a Signal request; responder was canceled. continuing."); } None => { unreachable!("outgoing responses should never fuse") @@ -285,6 +308,7 @@ impl SignalWebSocketProcess { } } } + Ok(()) } } @@ -293,9 +317,8 @@ impl SignalWebSocket { self.inner.lock().unwrap() } - pub fn from_socket( - ws: WS, - stream: WS::Stream, + pub fn from_socket( + ws: WebSocket, keep_alive_path: String, ) -> (Self, impl Future) { // Create process @@ -306,7 +329,7 @@ impl SignalWebSocket { keep_alive_path, requests: outgoing_requests, request_sink: incoming_request_sink, - outgoing_request_map: HashMap::default(), + outgoing_requests: HashMap::default(), outgoing_keep_alive_set: HashSet::new(), // Initializing the FuturesUnordered with a `pending` future means it will never fuse // itself, so an "empty" FuturesUnordered will still allow new futures to be added. @@ -316,7 +339,6 @@ impl SignalWebSocket { .into_iter() .collect(), ws, - stream, }; let process = process.run().map(|x| match x { Ok(()) => (), @@ -389,15 +411,14 @@ impl SignalWebSocket { async move { if let Err(_e) = request_sink.send((r, sink)).await { return Err(ServiceError::WsClosing { - reason: "WebSocket closing while sending request.".into(), + reason: "WebSocket closing while sending request", }); } // Handle the oneshot sender error for dropped senders. match recv.await { Ok(x) => x, Err(_) => Err(ServiceError::WsClosing { - reason: "WebSocket closing while waiting for a response." - .into(), + reason: "WebSocket closing while waiting for a response", }), } } @@ -410,116 +431,12 @@ impl SignalWebSocket { where for<'de> T: serde::Deserialize<'de>, { - let response = self.request(r).await?; - if response.status() != 200 { - tracing::debug!( - "request_json with non-200 status code. message: {}", - response.message() - ); - } - - fn json(body: &[u8]) -> Result - where - for<'de> U: serde::Deserialize<'de>, - { - serde_json::from_slice(body).map_err(|e| { - ServiceError::JsonDecodeError { - reason: e.to_string(), - } - }) - } - - match response.status() { - 200 | 204 => json(response.body()), - 401 | 403 => Err(ServiceError::Unauthorized), - 404 => Err(ServiceError::NotFoundError), - 413 /* PAYLOAD_TOO_LARGE */ => Err(ServiceError::RateLimitExceeded) , - 409 /* CONFLICT */ => { - let mismatched_devices: MismatchedDevices = - json(response.body()).map_err(|e| { - tracing::error!( - "Failed to decode HTTP 409 response: {}", - e - ); - ServiceError::UnhandledResponseCode { - http_code: 409, - } - })?; - Err(ServiceError::MismatchedDevicesException( - mismatched_devices, - )) - }, - 410 /* GONE */ => { - let stale_devices = - json(response.body()).map_err(|e| { - tracing::error!( - "Failed to decode HTTP 410 response: {}", - e - ); - ServiceError::UnhandledResponseCode { - http_code: 410, - } - })?; - Err(ServiceError::StaleDevices(stale_devices)) - }, - 423 /* LOCKED */ => { - let locked = json(response.body()).map_err(|e| { - tracing::error!("Failed to decode HTTP 423 response: {}", e); - ServiceError::UnhandledResponseCode { - http_code: 423, - } - })?; - Err(ServiceError::Locked(locked)) - }, - 428 /* PRECONDITION_REQUIRED */ => { - let proof_required = json(response.body()).map_err(|e| { - tracing::error!("Failed to decode HTTP 428 response: {}", e); - ServiceError::UnhandledResponseCode { - http_code: 428, - } - })?; - Err(ServiceError::ProofRequiredError(proof_required)) - }, - _ => Err(ServiceError::UnhandledResponseCode { - http_code: response.status() as u16, - }), - } - } - - pub(crate) async fn put_json( - &mut self, - path: &str, - value: S, - ) -> Result - where - for<'de> D: Deserialize<'de>, - S: Serialize, - { - self.put_json_with_headers(path, value, vec![]).await - } - - pub(crate) async fn put_json_with_headers<'h, D, S>( - &mut self, - path: &str, - value: S, - mut extra_headers: Vec, - ) -> Result - where - for<'de> D: Deserialize<'de>, - S: Serialize, - { - extra_headers.push("content-type:application/json".into()); - let request = WebSocketRequestMessage { - path: Some(path.into()), - verb: Some("PUT".into()), - headers: extra_headers, - body: Some(serde_json::to_vec(&value).map_err(|e| { - ServiceError::SendError { - reason: format!("Serializing JSON {}", e), - } - })?), - ..Default::default() - }; - self.request_json(request).await + self.request(r) + .await? + .service_error_for_status() + .await? + .json() + .await + .map_err(Into::into) } } diff --git a/src/websocket/request.rs b/src/websocket/request.rs new file mode 100644 index 000000000..d10306f29 --- /dev/null +++ b/src/websocket/request.rs @@ -0,0 +1,52 @@ +use reqwest::Method; +use serde::Serialize; + +use crate::proto::WebSocketRequestMessage; + +#[derive(Debug)] +pub struct WebSocketRequestMessageBuilder { + request: WebSocketRequestMessage, +} + +impl WebSocketRequestMessage { + #[allow(clippy::new_ret_no_self)] + pub fn new(method: Method) -> WebSocketRequestMessageBuilder { + WebSocketRequestMessageBuilder { + request: WebSocketRequestMessage { + verb: Some(method.to_string()), + ..Default::default() + }, + } + } +} + +impl WebSocketRequestMessageBuilder { + pub fn id(mut self, id: u64) -> Self { + self.request.id = Some(id); + self + } + + pub fn path(mut self, path: impl Into) -> Self { + self.request.path = Some(path.into()); + self + } + + pub fn header(mut self, key: &str, value: impl AsRef) -> Self { + self.request + .headers + .push(format!("{key}:{}", value.as_ref())); + self + } + + pub fn json( + mut self, + value: S, + ) -> Result { + self.request.body = Some(serde_json::to_vec(&value)?); + Ok(self.header("content-type", "application/json").request) + } + + pub fn build(self) -> WebSocketRequestMessage { + self.request + } +} diff --git a/src/websocket/sender.rs b/src/websocket/sender.rs index e9ea90ef6..394cf4c12 100644 --- a/src/websocket/sender.rs +++ b/src/websocket/sender.rs @@ -12,8 +12,10 @@ impl SignalWebSocket { &mut self, messages: OutgoingPushMessages, ) -> Result { - let path = format!("/v1/messages/{}", messages.destination); - self.put_json(&path, messages).await + let request = WebSocketRequestMessage::new(Method::PUT) + .path(format!("/v1/messages/{}", messages.destination)) + .json(&messages)?; + self.request_json(request).await } pub async fn send_messages_unidentified( @@ -21,12 +23,13 @@ impl SignalWebSocket { messages: OutgoingPushMessages, access: &UnidentifiedAccess, ) -> Result { - let path = format!("/v1/messages/{}", messages.destination); - let header = format!( - "Unidentified-Access-Key:{}", - BASE64_RELAXED.encode(&access.key) - ); - self.put_json_with_headers(&path, messages, vec![header]) - .await + let request = WebSocketRequestMessage::new(Method::PUT) + .path(format!("/v1/messages/{}", messages.destination)) + .header( + "Unidentified-Access-Key:{}", + BASE64_RELAXED.encode(&access.key), + ) + .json(&messages)?; + self.request_json(request).await } } diff --git a/src/websocket/tungstenite.rs b/src/websocket/tungstenite.rs deleted file mode 100644 index ae40caf09..000000000 --- a/src/websocket/tungstenite.rs +++ /dev/null @@ -1,202 +0,0 @@ -use std::sync::Arc; - -use async_tungstenite::{ - tokio::connect_async_with_tls_connector, - tungstenite::{ - client::IntoClientRequest, - http::{HeaderName, StatusCode}, - Error as TungsteniteError, Message, - }, -}; -use bytes::Bytes; -use futures::{channel::mpsc::*, prelude::*}; -use tokio::time::Instant; -use tokio_rustls::rustls; -use url::Url; - -use crate::{ - configuration::ServiceCredentials, - push_service::{self, ServiceError}, -}; - -use crate::messagepipe::{WebSocketService, WebSocketStreamItem}; - -pub struct TungsteniteWebSocket { - socket_sink: - Box + Send + Unpin>, -} - -#[derive(thiserror::Error, Debug)] -pub enum TungsteniteWebSocketError { - #[error("error while connecting to websocket: {0}")] - ConnectionError(#[from] TungsteniteError), -} - -impl From for ServiceError { - fn from(e: TungsteniteWebSocketError) -> Self { - match e { - TungsteniteWebSocketError::ConnectionError( - TungsteniteError::Http(response), - ) => match response.status() { - StatusCode::FORBIDDEN => ServiceError::Unauthorized, - s => ServiceError::WsError { - reason: format!("HTTP status {}", s), - }, - }, - e => ServiceError::WsError { - reason: e.to_string(), - }, - } - } -} - -// Process the WebSocket, until it times out. -async fn process( - socket_stream: S, - mut incoming_sink: Sender, -) -> Result<(), TungsteniteWebSocketError> -where - S: Stream> + Unpin, -{ - let mut socket_stream = socket_stream.fuse(); - - let mut ka_interval = tokio::time::interval_at( - Instant::now(), - push_service::KEEPALIVE_TIMEOUT_SECONDS, - ); - - loop { - tokio::select! { - _ = ka_interval.tick() => { - tracing::trace!("Triggering keep-alive"); - if let Err(e) = incoming_sink.send(WebSocketStreamItem::KeepAliveRequest).await { - tracing::info!("Websocket sink has closed: {:?}.", e); - break; - }; - }, - frame = socket_stream.next() => { - let frame = if let Some(frame) = frame { - frame - } else { - tracing::info!("process: Socket stream ended"); - break; - }; - - let frame = match frame? { - Message::Binary(s) => s, - Message::Ping(msg) => { - tracing::warn!("Received Ping({:?})", msg); - - continue; - }, - Message::Pong(msg) => { - tracing::trace!("Received Pong({:?})", msg); - - continue; - }, - Message::Text(frame) => { - tracing::warn!("Message::Text {:?}", frame); - - // this is a protocol violation, maybe break; is better? - continue; - }, - - Message::Close(c) => { - tracing::warn!("Websocket closing: {:?}", c); - - break; - }, - Message::Frame(_f) => unreachable!("handled internally in Tungstenite") - }; - - // Match SendError - if let Err(e) = incoming_sink.send(WebSocketStreamItem::Message(Bytes::from(frame))).await { - tracing::info!("Websocket sink has closed: {:?}.", e); - break; - } - }, - } - } - Ok(()) -} - -impl TungsteniteWebSocket { - pub(crate) async fn with_tls_config( - tls_config: rustls::ClientConfig, - base_url: impl std::borrow::Borrow, - path: &str, - additional_headers: &[(&str, &str)], - credentials: Option<&ServiceCredentials>, - ) -> Result< - (Self, ::Stream), - TungsteniteWebSocketError, - > { - let mut url = base_url.borrow().join(path).expect("valid url"); - url.set_scheme("wss").expect("valid https base url"); - - let tls_connector = - tokio_rustls::TlsConnector::from(Arc::new(tls_config)); - - if let Some(credentials) = credentials { - url.query_pairs_mut() - .append_pair("login", &credentials.login()) - .append_pair( - "password", - credentials.password.as_ref().expect("a password"), - ); - } - - tracing::trace!("Will start websocket at {:?}", url); - - let mut request = url.into_client_request()?; - - for (key, value) in additional_headers { - request.headers_mut().insert( - // FromStr is implemnted for HeaderName, but that expects a &'static str... - HeaderName::from_bytes(key.as_bytes()) - .expect("valid header name"), - value.parse().expect("valid header value"), - ); - } - - let (socket_stream, response) = - connect_async_with_tls_connector(request, Some(tls_connector)) - .await?; - - tracing::debug!("WebSocket connected: {:?}", response); - - let (incoming_sink, incoming_stream) = channel(5); - - let (socket_sink, socket_stream) = socket_stream.split(); - let processing_task = process(socket_stream, incoming_sink); - - // When the processing_task stops, the consuming stream and sink also - // terminate. - tokio::spawn(processing_task.map(|v| match v { - Ok(()) => (), - Err(e) => { - tracing::warn!("Processing task terminated with error: {:?}", e) - }, - })); - - Ok(( - Self { - socket_sink: Box::new(socket_sink), - }, - incoming_stream, - )) - } -} - -#[async_trait::async_trait] -impl WebSocketService for TungsteniteWebSocket { - type Stream = Receiver; - - async fn send_message(&mut self, msg: Bytes) -> Result<(), ServiceError> { - self.socket_sink - .send(Message::Binary(msg.to_vec())) - .await - .map_err(TungsteniteWebSocketError::from)?; - Ok(()) - } -}