Skip to content

Commit

Permalink
chore(deps): update tokio-tungstenite requirement from 0.24 to 0.25
Browse files Browse the repository at this point in the history
  • Loading branch information
chrislearn committed Dec 17, 2024
1 parent 46f624a commit b7cfd65
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 58 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ tokio-native-tls = "0.3"
tokio-rustls = {version = "0.26", default-features = false }
tokio-openssl = "0.6"
tokio-stream = { version = "0.1", default-features = false }
tokio-tungstenite = { version = "0.24", default-features = false }
tokio-tungstenite = { version = "0.25", default-features = false }
tokio-util = "0.7"
tower = { version = "0.5", default-features = false }
tracing-subscriber = { version = "0.3" }
Expand Down
119 changes: 73 additions & 46 deletions crates/extra/src/websocket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
//! // client disconnected
//! return;
//! };
//!
//!
//! if ws.send(msg).await.is_err() {
//! // client disconnected
//! return;
Expand All @@ -47,11 +47,11 @@
//! #[tokio::main]
//! async fn main() {
//! let router = Router::new().get(index).push(Router::with_path("ws").goal(connect));
//!
//!
//! let acceptor = TcpListener::new("0.0.0.0:5800").bind().await;
//! Server::new(acceptor).serve(router).await;
//! }
//!
//!
//! static INDEX_HTML: &str = r#"<!DOCTYPE html>
//! <html>
//! <head>
Expand Down Expand Up @@ -81,21 +81,22 @@ use std::borrow::Cow;
use std::fmt::{self, Debug, Formatter};
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll, ready};
use std::task::{ready, Context, Poll};

use futures_util::sink::{Sink, SinkExt};
use futures_util::stream::{Stream, StreamExt};
use futures_util::{future, FutureExt, TryFutureExt};
use hyper::upgrade::OnUpgrade;
use salvo_core::http::header::{SEC_WEBSOCKET_VERSION, UPGRADE};
use salvo_core::http::headers::{Connection, HeaderMapExt, SecWebsocketAccept, SecWebsocketKey, Upgrade};
use salvo_core::http::headers::{
Connection, HeaderMapExt, SecWebsocketAccept, SecWebsocketKey, Upgrade,
};
use salvo_core::http::{StatusCode, StatusError};
use salvo_core::rt::tokio::TokioIo;
use salvo_core::{Error, Request, Response};
use tokio_tungstenite::{
tungstenite::protocol::{self, WebSocketConfig},
WebSocketStream,
};
use tokio_tungstenite::tungstenite::protocol::frame::{Payload, Utf8Payload};
use tokio_tungstenite::tungstenite::protocol::{self, WebSocketConfig};
use tokio_tungstenite::WebSocketStream;

/// Creates a WebSocket Handler.
/// Request:
Expand Down Expand Up @@ -132,7 +133,9 @@ impl WebSocketUpgrade {
/// Create new `WebSocketUpgrade` with config.
#[inline]
pub fn with_config(config: WebSocketConfig) -> Self {
WebSocketUpgrade { config: Some(config) }
WebSocketUpgrade {
config: Some(config),
}
}

/// The target minimum size of the write buffer to reach before writing the data
Expand All @@ -143,7 +146,9 @@ impl WebSocketUpgrade {
/// It is often more optimal to allow them to buffer a little, hence the default value.
#[inline]
pub fn write_buffer_size(mut self, max: usize) -> Self {
self.config.get_or_insert_with(WebSocketConfig::default).write_buffer_size = max;
self.config
.get_or_insert_with(WebSocketConfig::default)
.write_buffer_size = max;
self
}

Expand All @@ -159,7 +164,9 @@ impl WebSocketUpgrade {
/// and probably a little more depending on error handling strategy.
#[inline]
pub fn max_write_buffer_size(mut self, max: usize) -> Self {
self.config.get_or_insert_with(WebSocketConfig::default).max_write_buffer_size = max;
self.config
.get_or_insert_with(WebSocketConfig::default)
.max_write_buffer_size = max;
self
}

Expand All @@ -180,7 +187,9 @@ impl WebSocketUpgrade {
/// by a malicious user.
#[inline]
pub fn max_frame_size(mut self, max: usize) -> Self {
self.config.get_or_insert_with(WebSocketConfig::default).max_frame_size = Some(max);
self.config
.get_or_insert_with(WebSocketConfig::default)
.max_frame_size = Some(max);
self
}

Expand All @@ -191,13 +200,19 @@ impl WebSocketUpgrade {
/// By default this option is set to `false`, i.e. according to RFC 6455.
#[inline]
pub fn accept_unmasked_frames(mut self, accept: bool) -> Self {
self.config.get_or_insert_with(WebSocketConfig::default).accept_unmasked_frames = accept;
self.config
.get_or_insert_with(WebSocketConfig::default)
.accept_unmasked_frames = accept;
self
}


/// Upgrade websocket request.
pub async fn upgrade<F, Fut>(&self, req: &mut Request, res: &mut Response, callback: F) -> Result<(), StatusError>
pub async fn upgrade<F, Fut>(
&self,
req: &mut Request,
res: &mut Response,
callback: F,
) -> Result<(), StatusError>
where
F: FnOnce(WebSocket) -> Fut + Send + 'static,
Fut: Future<Output = ()> + Send + 'static,
Expand All @@ -218,7 +233,8 @@ impl WebSocketUpgrade {
.unwrap_or(false);
if !matched {
tracing::debug!("missing upgrade header or it is not equal websocket");
return Err(StatusError::bad_request().brief("Missing upgrade header or it is not equal websocket."));
return Err(StatusError::bad_request()
.brief("Missing upgrade header or it is not equal websocket."));
}
let matched = !req_headers
.get(SEC_WEBSOCKET_VERSION)
Expand All @@ -233,14 +249,16 @@ impl WebSocketUpgrade {
key
} else {
tracing::debug!("sec_websocket_key is not exist in request headers");
return Err(StatusError::bad_request().brief("sec_websocket_key is not exist in request headers."));
return Err(StatusError::bad_request()
.brief("sec_websocket_key is not exist in request headers."));
};

res.status_code(StatusCode::SWITCHING_PROTOCOLS);

res.headers_mut().typed_insert(Connection::upgrade());
res.headers_mut().typed_insert(Upgrade::websocket());
res.headers_mut().typed_insert(SecWebsocketAccept::from(sec_ws_key));
res.headers_mut()
.typed_insert(SecWebsocketAccept::from(sec_ws_key));

if let Some(on_upgrade) = req.extensions_mut().remove::<OnUpgrade>() {
let config = self.config;
Expand All @@ -257,7 +275,8 @@ impl WebSocketUpgrade {
Ok(())
} else {
tracing::debug!("websocket couldn't be upgraded since no upgrade state was present");
Err(StatusError::bad_request().brief("Websocket couldn't be upgraded since no upgrade state was present."))
Err(StatusError::bad_request()
.brief("Websocket couldn't be upgraded since no upgrade state was present."))
}
}
}
Expand Down Expand Up @@ -326,22 +345,30 @@ impl Sink<Message> for WebSocket {

#[inline]
fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Pin::new(&mut self.inner).poll_ready(cx).map_err(Error::other)
Pin::new(&mut self.inner)
.poll_ready(cx)
.map_err(Error::other)
}

#[inline]
fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
Pin::new(&mut self.inner).start_send(item.inner).map_err(Error::other)
Pin::new(&mut self.inner)
.start_send(item.inner)
.map_err(Error::other)
}

#[inline]
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
Pin::new(&mut self.inner).poll_flush(cx).map_err(Error::other)
Pin::new(&mut self.inner)
.poll_flush(cx)
.map_err(Error::other)
}

#[inline]
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
Pin::new(&mut self.inner).poll_close(cx).map_err(Error::other)
Pin::new(&mut self.inner)
.poll_close(cx)
.map_err(Error::other)
}
}

Expand All @@ -364,31 +391,31 @@ pub struct Message {
impl Message {
/// Construct a new Text `Message`.
#[inline]
pub fn text<S: Into<String>>(s: S) -> Message {
pub fn text<S: Into<Utf8Payload>>(s: S) -> Message {
Message {
inner: protocol::Message::text(s),
}
}

/// Construct a new Binary `Message`.
#[inline]
pub fn binary<V: Into<Vec<u8>>>(v: V) -> Message {
pub fn binary<V: Into<Payload>>(v: V) -> Message {
Message {
inner: protocol::Message::binary(v),
}
}

/// Construct a new Ping `Message`.
#[inline]
pub fn ping<V: Into<Vec<u8>>>(v: V) -> Message {
pub fn ping<V: Into<Payload>>(v: V) -> Message {
Message {
inner: protocol::Message::Ping(v.into()),
}
}

/// Construct a new Pong `Message`.
#[inline]
pub fn pong<V: Into<Vec<u8>>>(v: V) -> Message {
pub fn pong<V: Into<Payload>>(v: V) -> Message {
Message {
inner: protocol::Message::Pong(v.into()),
}
Expand Down Expand Up @@ -455,31 +482,25 @@ impl Message {

/// Try to get a reference to the string text, if this is a Text message.
#[inline]
pub fn to_str(&self) -> Result<&str, Error> {
match self.inner {
protocol::Message::Text(ref s) => Ok(s),
pub fn as_str(&self) -> Result<&str, Error> {
match &self.inner {
protocol::Message::Text(s) => Ok(s.as_str()),
_ => Err(Error::Other("not a text message".into())),
}
}

/// Returns the bytes of this message, if the message can contain data.
#[inline]
pub fn as_bytes(&self) -> &[u8] {
match self.inner {
protocol::Message::Text(ref s) => s.as_bytes(),
protocol::Message::Binary(ref v) => v,
protocol::Message::Ping(ref v) => v,
protocol::Message::Pong(ref v) => v,
match &self.inner {
protocol::Message::Text(s) => s.as_slice(),
protocol::Message::Binary(v) => v.as_slice(),
protocol::Message::Ping(v) => v.as_slice(),
protocol::Message::Pong(v) => v.as_slice(),
protocol::Message::Close(_) => &[],
protocol::Message::Frame(ref v) => v.payload(),
protocol::Message::Frame(v) => v.payload(),
}
}

/// Destructure this message into binary data.
#[inline]
pub fn into_bytes(self) -> Vec<u8> {
self.inner.into_data()
}
}

impl Debug for Message {
Expand All @@ -493,7 +514,7 @@ impl Debug for Message {
impl Into<Vec<u8>> for Message {
#[inline]
fn into(self) -> Vec<u8> {
self.into_bytes()
self.as_bytes().into()
}
}

Expand Down Expand Up @@ -529,15 +550,21 @@ mod tests {
async fn test_websocket() {
let router = Router::new().goal(connect);
let acceptor = TcpListener::new("127.0.0.1:0").bind().await;
let addr = acceptor.holdings()[0].local_addr.clone().into_std().unwrap();
let addr = acceptor.holdings()[0]
.local_addr
.clone()
.into_std()
.unwrap();

tokio::spawn(async move {
Server::new(acceptor).serve(router).await;
});

let stream = tokio::net::TcpStream::connect(addr).await.unwrap();

let (mut sender, conn) = hyper::client::conn::http1::handshake(TokioIo::new(stream)).await.unwrap();
let (mut sender, conn) = hyper::client::conn::http1::handshake(TokioIo::new(stream))
.await
.unwrap();
tokio::task::spawn(async move {
if let Err(err) = conn.await {
println!("Connection failed: {:?}", err);
Expand Down
9 changes: 4 additions & 5 deletions examples/otel-jaeger/src/server1.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,10 @@ fn init_tracer_provider() -> TracerProvider {
.expect("failed to create exporter");
TracerProvider::builder()
.with_batch_exporter(exporter, runtime::Tokio)
.with_config(
opentelemetry_sdk::trace::Config::default().with_resource(Resource::new(vec![
KeyValue::new("service.name", "server1"),
])),
)
.with_resource(Resource::new(vec![KeyValue::new(
"service.name",
"server1",
)]))
.build()
}

Expand Down
9 changes: 4 additions & 5 deletions examples/otel-jaeger/src/server2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,10 @@ fn init_tracer_provider() -> TracerProvider {
.build()
.expect("failed to create exporter");
opentelemetry_sdk::trace::TracerProvider::builder()
.with_config(
opentelemetry_sdk::trace::Config::default().with_resource(Resource::new(vec![
KeyValue::new("service.name", "server2"),
])),
)
.with_resource(Resource::new(vec![KeyValue::new(
"service.name",
"server2",
)]))
.with_batch_exporter(exporter, runtime::Tokio)
.build()
}
Expand Down
2 changes: 1 addition & 1 deletion examples/websocket-chat/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ async fn handle_socket(ws: WebSocket) {
tokio::task::spawn(fut);
}
async fn user_message(my_id: usize, msg: Message) {
let msg = if let Ok(s) = msg.to_str() {
let msg = if let Ok(s) = msg.as_str() {
s
} else {
return;
Expand Down

0 comments on commit b7cfd65

Please sign in to comment.