Skip to content

Commit

Permalink
Move some helper functions to util module
Browse files Browse the repository at this point in the history
  • Loading branch information
piegamesde committed Feb 5, 2023
1 parent eab91a5 commit 485bf19
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 44 deletions.
6 changes: 3 additions & 3 deletions src/transfer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use serde_derive::{Deserialize, Serialize};
use serde_json::json;
use std::sync::Arc;

use super::{core::WormholeError, transit, transit::Transit, AppID, Wormhole};
use super::{core::WormholeError, transit, transit::Transit, util, AppID, Wormhole};
use futures::Future;
use log::*;
use std::{borrow::Cow, path::PathBuf};
Expand Down Expand Up @@ -517,7 +517,7 @@ async fn handle_run_result(
result: Result<(Result<(), TransferError>, impl Future<Output = ()>), crate::util::Cancelled>,
) -> Result<(), TransferError> {
async fn wrap_timeout(run: impl Future<Output = ()>, cancel: impl Future<Output = ()>) {
let run = transit::timeout(SHUTDOWN_TIME, run);
let run = util::timeout(SHUTDOWN_TIME, run);
futures::pin_mut!(run);
match crate::util::cancellable(run, cancel).await {
Ok(Ok(())) => {},
Expand Down Expand Up @@ -573,7 +573,7 @@ async fn handle_run_result(
// and we should not only look for the next one but all have been received
// and we should not interrupt a receive operation without making sure it leaves the connection
// in a consistent state, otherwise the shutdown may cause protocol errors
if let Ok(Ok(Ok(PeerMessage::Error(e)))) = transit::timeout(SHUTDOWN_TIME / 3, wormhole.receive_json()).await {
if let Ok(Ok(Ok(PeerMessage::Error(e)))) = util::timeout(SHUTDOWN_TIME / 3, wormhole.receive_json()).await {
error = TransferError::PeerError(e);
} else {
log::debug!("Failed to retrieve more specific error message from peer. Maybe it crashed?");
Expand Down
48 changes: 7 additions & 41 deletions src/transit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
//! **Notice:** while the resulting TCP connection is naturally bi-directional, the handshake is not symmetric. There *must* be one
//! "leader" side and one "follower" side (formerly called "sender" and "receiver").
use crate::{Key, KeyPurpose};
use crate::{util, Key, KeyPurpose};
use serde_derive::{Deserialize, Serialize};

#[cfg(not(target_family = "wasm"))]
Expand Down Expand Up @@ -691,7 +691,7 @@ pub async fn init(
* so that we will be NATted to the same port again. If it doesn't, simply bind a new socket
* and use that instead.
*/
let socket: MaybeConnectedSocket = match timeout(
let socket: MaybeConnectedSocket = match util::timeout(
std::time::Duration::from_secs(4),
transport::tcp_get_external_ip(),
)
Expand Down Expand Up @@ -874,7 +874,7 @@ impl TransitConnector {
);

let (mut transit, mut finalizer, mut conn_info) =
timeout(std::time::Duration::from_secs(60), connection_stream.next())
util::timeout(std::time::Duration::from_secs(60), connection_stream.next())
.await
.map_err(|_| {
log::debug!("`leader_connect` timed out");
Expand All @@ -896,7 +896,7 @@ impl TransitConnector {
} else {
elapsed.mul_f32(0.3)
};
let _ = timeout(to_wait, async {
let _ = util::timeout(to_wait, async {
while let Some((new_transit, new_finalizer, new_conn_info)) =
connection_stream.next().await
{
Expand Down Expand Up @@ -978,7 +978,7 @@ impl TransitConnector {
}),
);

let transit = match timeout(
let transit = match util::timeout(
std::time::Duration::from_secs(60),
&mut connection_stream.next(),
)
Expand Down Expand Up @@ -1125,7 +1125,7 @@ impl TransitConnector {
.map(move |(i, h)| (i, h, name.clone()))
})
.map(|(index, host, name)| async move {
sleep(std::time::Duration::from_secs(
util::sleep(std::time::Duration::from_secs(
index as u64 * 5,
))
.await;
Expand Down Expand Up @@ -1169,7 +1169,7 @@ impl TransitConnector {
.map(move |(i, u)| (i, u, name.clone()))
})
.map(|(index, url, name)| async move {
sleep(std::time::Duration::from_secs(
util::sleep(std::time::Duration::from_secs(
index as u64 * 5,
))
.await;
Expand Down Expand Up @@ -1369,40 +1369,6 @@ async fn handshake_exchange(
Ok((socket, finalizer))
}

#[cfg(not(target_family = "wasm"))]
pub(super) async fn sleep(duration: std::time::Duration) {
async_std::task::sleep(duration).await
}

#[cfg(target_family = "wasm")]
pub(super) async fn sleep(duration: std::time::Duration) {
/* Skip error handling. Waiting is best effort anyways */
let _ = wasm_timer::Delay::new(duration).await;
}

#[cfg(not(target_family = "wasm"))]
pub(super) async fn timeout<F, T>(
duration: std::time::Duration,
future: F,
) -> Result<T, async_std::future::TimeoutError>
where
F: futures::Future<Output = T>,
{
async_std::future::timeout(duration, future).await
}

#[cfg(target_family = "wasm")]
pub(super) async fn timeout<F, T>(
duration: std::time::Duration,
future: F,
) -> Result<T, std::io::Error>
where
F: futures::Future<Output = T>,
{
use wasm_timer::TryFutureExt;
future.map(Result::Ok).timeout(duration).await
}

#[cfg(test)]
mod test {
use super::*;
Expand Down
33 changes: 33 additions & 0 deletions src/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ impl std::fmt::Display for DisplayBytes<'_> {
* TODO remove after https://github.com/quininer/memsec/issues/11 is resolved.
* Original implementation: https://github.com/jedisct1/libsodium/blob/6d566070b48efd2fa099bbe9822914455150aba9/src/libsodium/sodium/utils.c#L262-L307
*/
#[allow(unused)]
pub fn sodium_increment_le(n: &mut [u8]) {
let mut c = 1u16;
for b in n {
Expand Down Expand Up @@ -209,3 +210,35 @@ impl std::fmt::Display for Cancelled {
write!(f, "Task has been cancelled")
}
}

#[cfg(not(target_family = "wasm"))]
pub async fn sleep(duration: std::time::Duration) {
async_std::task::sleep(duration).await
}

#[cfg(target_family = "wasm")]
pub async fn sleep(duration: std::time::Duration) {
/* Skip error handling. Waiting is best effort anyways */
let _ = wasm_timer::Delay::new(duration).await;
}

#[cfg(not(target_family = "wasm"))]
pub async fn timeout<F, T>(
duration: std::time::Duration,
future: F,
) -> Result<T, async_std::future::TimeoutError>
where
F: futures::Future<Output = T>,
{
async_std::future::timeout(duration, future).await
}

#[cfg(target_family = "wasm")]
pub async fn timeout<F, T>(duration: std::time::Duration, future: F) -> Result<T, std::io::Error>
where
F: futures::Future<Output = T>,
{
use futures::FutureExt;
use wasm_timer::TryFutureExt;
future.map(Result::Ok).timeout(duration).await
}

0 comments on commit 485bf19

Please sign in to comment.