Skip to content

Commit

Permalink
chore: upgrade to hyper 1.0 (#313)
Browse files Browse the repository at this point in the history
  • Loading branch information
boxdot authored Jul 29, 2024
1 parent 86dd9da commit bcb2821
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 57 deletions.
16 changes: 8 additions & 8 deletions libsignal-service-hyper/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,18 @@ serde_json = "1.0"
thiserror = "1.0"
url = "2.1"

# hyper rustls 0.25 is not compatible with hyper 1 yet
# https://github.com/rustls/hyper-rustls/issues/234
hyper = { version = "0.14", features = ["client", "stream"] }
hyper-rustls = { version = "0.25", features=["http1", "http2"] }
hyper-timeout = "0.4"
headers = "0.3"
hyper = "1.0"
hyper-util = { version = "0.1", features = ["client", "client-legacy"] }
hyper-rustls = { version = "0.27", features = ["http1", "http2"] }
hyper-timeout = "0.5"
headers = "0.4"
http-body-util = "0.1"

# for websocket support
async-tungstenite = { version = "0.24", features = ["tokio-rustls-native-certs"] }
async-tungstenite = { version = "0.27", features = ["tokio-rustls-native-certs", "url"] }

tokio = { version = "1.0", features = ["macros"] }
tokio-rustls = "0.25"
tokio-rustls = "0.26"

rustls-pemfile = "2.0"

Expand Down
97 changes: 48 additions & 49 deletions libsignal-service-hyper/src/push_service.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,28 @@
use std::{io::Read, time::Duration};
use std::io;
use std::time::Duration;

use bytes::{Buf, Bytes};
use futures::{FutureExt, StreamExt, TryStreamExt};
use headers::{Authorization, HeaderMapExt};
use http_body_util::{BodyExt, Full};
use hyper::{
client::HttpConnector,
body::Incoming,
header::{CONTENT_LENGTH, CONTENT_TYPE, USER_AGENT},
Body, Client, Method, Request, Response, StatusCode,
Method, Request, Response, StatusCode,
};
use hyper_rustls::HttpsConnector;
use hyper_timeout::TimeoutConnector;
use hyper_util::{
client::legacy::{connect::HttpConnector, Client},
rt::TokioExecutor,
};
use libsignal_service::{
configuration::*, prelude::ProtobufMessage, push_service::*,
websocket::SignalWebSocket, MaybeSend,
};
use serde::{Deserialize, Serialize};
use tokio_rustls::rustls;
use tokio_rustls::rustls::{self, ClientConfig};
use tracing::{debug, debug_span};
use tracing_futures::Instrument;

use crate::websocket::TungsteniteWebSocket;
Expand All @@ -25,7 +32,8 @@ pub struct HyperPushService {
cfg: ServiceConfiguration,
user_agent: String,
credentials: Option<HttpAuth>,
client: Client<TimeoutConnector<HttpsConnector<HttpConnector>>>,
client:
Client<TimeoutConnector<HttpsConnector<HttpConnector>>, Full<Bytes>>,
}

#[derive(Debug)]
Expand Down Expand Up @@ -55,8 +63,8 @@ impl HyperPushService {
timeout_connector.set_read_timeout(Some(Duration::from_secs(65)));
timeout_connector.set_write_timeout(Some(Duration::from_secs(65)));

let client: Client<_, hyper::Body> =
Client::builder().build(timeout_connector);
let client: Client<_, Full<Bytes>> =
Client::builder(TokioExecutor::new()).build(timeout_connector);

Self {
cfg,
Expand All @@ -66,8 +74,8 @@ impl HyperPushService {
}
}

fn tls_config(cfg: &ServiceConfiguration) -> rustls::ClientConfig {
let mut cert_bytes = std::io::Cursor::new(&cfg.certificate_authority);
fn tls_config(cfg: &ServiceConfiguration) -> ClientConfig {
let mut cert_bytes = io::Cursor::new(&cfg.certificate_authority);
let roots = rustls_pemfile::certs(&mut cert_bytes);

let mut root_certs = rustls::RootCertStore::empty();
Expand All @@ -89,7 +97,7 @@ impl HyperPushService {
additional_headers: &[(&str, &str)],
credentials_override: HttpAuthOverride,
body: Option<RequestBody>,
) -> Result<Response<Body>, ServiceError> {
) -> Result<Response<Incoming>, ServiceError> {
let url = self.cfg.base_url(endpoint).join(path.as_ref())?;
let mut builder = Request::builder()
.method(method)
Expand Down Expand Up @@ -128,10 +136,10 @@ impl HyperPushService {
builder
.header(CONTENT_LENGTH, contents.len() as u64)
.header(CONTENT_TYPE, content_type)
.body(Body::from(contents))
.body(Full::new(Bytes::from(contents)))
.unwrap()
} else {
builder.body(Body::empty()).unwrap()
builder.body(Full::default()).unwrap()
};

let mut response = self.client.request(request).await.map_err(|e| {
Expand Down Expand Up @@ -223,19 +231,26 @@ impl HyperPushService {
}
}

async fn body(
response: &mut Response<Incoming>,
) -> Result<impl Buf, ServiceError> {
Ok(response
.collect()
.await
.map_err(|e| ServiceError::ResponseError {
reason: format!("failed to aggregate HTTP response body: {e}"),
})?
.aggregate())
}

#[tracing::instrument(skip(response), fields(status = %response.status()))]
async fn json<T>(response: &mut Response<Body>) -> Result<T, ServiceError>
async fn json<T>(
response: &mut Response<Incoming>,
) -> Result<T, ServiceError>
where
for<'de> T: Deserialize<'de>,
{
let body = hyper::body::aggregate(response).await.map_err(|e| {
ServiceError::ResponseError {
reason: format!(
"failed to aggregate HTTP response body: {}",
e
),
}
})?;
let body = Self::body(response).await?;

if body.has_remaining() {
serde_json::from_reader(body.reader())
Expand All @@ -249,42 +264,25 @@ impl HyperPushService {

#[tracing::instrument(skip(response), fields(status = %response.status()))]
async fn protobuf<M>(
response: &mut Response<Body>,
response: &mut Response<Incoming>,
) -> Result<M, ServiceError>
where
M: ProtobufMessage + Default,
{
let body = hyper::body::aggregate(response).await.map_err(|e| {
ServiceError::ResponseError {
reason: format!(
"failed to aggregate HTTP response body: {}",
e
),
}
})?;

let body = Self::body(response).await?;
M::decode(body).map_err(ServiceError::ProtobufDecodeError)
}

#[tracing::instrument(skip(response), fields(status = %response.status()))]
async fn text(
response: &mut Response<Body>,
response: &mut Response<Incoming>,
) -> Result<String, ServiceError> {
let body = hyper::body::aggregate(response).await.map_err(|e| {
ServiceError::ResponseError {
reason: format!(
"failed to aggregate HTTP response body: {}",
e
),
}
})?;
let mut text = String::new();
body.reader().read_to_string(&mut text).map_err(|e| {
let body = Self::body(response).await?;
io::read_to_string(body.reader()).map_err(|e| {
ServiceError::ResponseError {
reason: format!("failed to read HTTP response body: {}", e),
reason: format!("failed to read HTTP response body: {e}"),
}
})?;
Ok(text)
})
}
}

Expand Down Expand Up @@ -527,13 +525,14 @@ impl PushService for HyperPushService {
Ok(Box::new(
response
.into_body()
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))
.into_data_stream()
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
.into_async_read(),
))
}

#[tracing::instrument(skip(self, value, file), fields(file = file.as_ref().map(|_| "")))]
async fn post_to_cdn0<'s, C: std::io::Read + Send + 's>(
async fn post_to_cdn0<'s, C: io::Read + Send + 's>(
&mut self,
path: &str,
value: &[(&str, &str)],
Expand Down Expand Up @@ -597,7 +596,7 @@ impl PushService for HyperPushService {
)
.await?;

tracing::debug!("HyperPushService::PUT response: {:?}", response);
debug!("HyperPushService::PUT response: {:?}", response);

Ok(())
}
Expand All @@ -609,7 +608,7 @@ impl PushService for HyperPushService {
additional_headers: &[(&str, &str)],
credentials: Option<ServiceCredentials>,
) -> Result<SignalWebSocket, ServiceError> {
let span = tracing::debug_span!("websocket");
let span = debug_span!("websocket");
let (ws, stream) = TungsteniteWebSocket::with_tls_config(
Self::tls_config(&self.cfg),
self.cfg.base_url(Endpoint::Service),
Expand Down

0 comments on commit bcb2821

Please sign in to comment.