diff --git a/src/websocket/mod.rs b/src/websocket/mod.rs index 128956eac..76d7d362b 100644 --- a/src/websocket/mod.rs +++ b/src/websocket/mod.rs @@ -9,8 +9,8 @@ use futures::channel::{mpsc, oneshot}; use futures::future::BoxFuture; use futures::prelude::*; use futures::stream::FuturesUnordered; +use reqwest::Method; use reqwest_websocket::WebSocket; -use serde::{Deserialize, Serialize}; use tokio::time::Instant; use crate::proto::{ @@ -19,8 +19,10 @@ use crate::proto::{ }; use crate::push_service::{self, ServiceError, SignalServiceResponse}; +mod request; mod sender; -// pub(crate) mod tungstenite; + +pub use request::WebSocketRequestMessageBuilder; type RequestStreamItem = ( WebSocketRequestMessage, @@ -202,12 +204,10 @@ impl SignalWebSocketProcess { _ = ka_interval.tick().fuse() => { use prost::Message; tracing::debug!("sending keep-alive"); - let request = WebSocketRequestMessage { - id: Some(self.next_request_id()), - path: Some(self.keep_alive_path.clone()), - verb: Some("GET".into()), - ..Default::default() - }; + let request = WebSocketRequestMessage::new(Method::GET) + .id(self.next_request_id()) + .path(&self.keep_alive_path) + .build(); self.outgoing_keep_alive_set.insert(request.id.unwrap()); let msg = WebSocketMessage { r#type: Some(web_socket_message::Type::Request.into()), @@ -441,29 +441,4 @@ impl SignalWebSocket { .await .map_err(Into::into) } - - pub(crate) async fn put_json<'h, D, S>( - &mut self, - path: &str, - value: S, - mut extra_headers: Vec, - ) -> Result - where - for<'de> D: Deserialize<'de>, - S: Serialize, - { - extra_headers.push("content-type:application/json".into()); - let request = WebSocketRequestMessage { - path: Some(path.into()), - verb: Some("PUT".into()), - headers: extra_headers, - body: Some(serde_json::to_vec(&value).map_err(|e| { - ServiceError::SendError { - reason: format!("Serializing JSON {}", e), - } - })?), - ..Default::default() - }; - self.request_json(request).await - } } diff --git a/src/websocket/request.rs b/src/websocket/request.rs new file mode 100644 index 000000000..f7a7b01fb --- /dev/null +++ b/src/websocket/request.rs @@ -0,0 +1,51 @@ +use reqwest::Method; +use serde::Serialize; + +use crate::proto::WebSocketRequestMessage; + +#[derive(Debug)] +pub struct WebSocketRequestMessageBuilder { + request: WebSocketRequestMessage, +} + +impl WebSocketRequestMessage { + pub fn new(method: Method) -> WebSocketRequestMessageBuilder { + WebSocketRequestMessageBuilder { + request: WebSocketRequestMessage { + verb: Some(method.to_string()), + ..Default::default() + }, + } + } +} + +impl WebSocketRequestMessageBuilder { + pub fn id(mut self, id: u64) -> Self { + self.request.id = Some(id); + self + } + + pub fn path(mut self, path: impl Into) -> Self { + self.request.path = Some(path.into()); + self + } + + pub fn header(mut self, key: &str, value: impl AsRef) -> Self { + self.request + .headers + .push(format!("{key}={}", value.as_ref())); + self + } + + pub fn json( + mut self, + value: S, + ) -> Result { + self.request.body = Some(serde_json::to_vec(&value)?); + Ok(self.header("Content-Type", "application/json").request) + } + + pub fn build(self) -> WebSocketRequestMessage { + self.request + } +} diff --git a/src/websocket/sender.rs b/src/websocket/sender.rs index 341430b38..394cf4c12 100644 --- a/src/websocket/sender.rs +++ b/src/websocket/sender.rs @@ -12,8 +12,10 @@ impl SignalWebSocket { &mut self, messages: OutgoingPushMessages, ) -> Result { - let path = format!("/v1/messages/{}", messages.destination); - self.put_json(&path, messages, vec![]).await + let request = WebSocketRequestMessage::new(Method::PUT) + .path(format!("/v1/messages/{}", messages.destination)) + .json(&messages)?; + self.request_json(request).await } pub async fn send_messages_unidentified( @@ -21,11 +23,13 @@ impl SignalWebSocket { messages: OutgoingPushMessages, access: &UnidentifiedAccess, ) -> Result { - let path = format!("/v1/messages/{}", messages.destination); - let header = format!( - "Unidentified-Access-Key:{}", - BASE64_RELAXED.encode(&access.key) - ); - self.put_json(&path, messages, vec![header]).await + let request = WebSocketRequestMessage::new(Method::PUT) + .path(format!("/v1/messages/{}", messages.destination)) + .header( + "Unidentified-Access-Key:{}", + BASE64_RELAXED.encode(&access.key), + ) + .json(&messages)?; + self.request_json(request).await } }