diff --git a/Cargo.toml b/Cargo.toml index 2242a39..f77ea53 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,6 +26,7 @@ log = "0.4" tokio = { version = "1", features = ["sync"] } tracing = { version = "0.1", optional = true } tracing-futures = { version = "0.2", optional = true } +url = "2" [dev-dependencies] tokio = { version = "1", features = ["macros", "rt-multi-thread", "net", "io-util"] } diff --git a/src/lib.rs b/src/lib.rs index 8e0ea38..0b74f0d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -169,6 +169,7 @@ pub trait Connector { pub struct RoundRobin where Conn: Connector, + SvcSrc: ToString + Clone, { /// Sources used to connect to a service. Usually some form of URL to attempt a connection, /// e.g. `amqp://localhost:5672` @@ -192,7 +193,7 @@ where impl RoundRobin where - SvcSrc: Debug, + SvcSrc: Debug + Display + Clone, E: Next + Display, Conn: Connector, { @@ -264,8 +265,8 @@ where #[cfg(feature = "tracing")] { let span = Span::current(); - span.record("index", &display(index)); - span.record("service", &debug(&self.sources[index])); + span.record("index", display(index)); + span.record("service", debug(self.sources[index].clone())); } // Connect if not already connected diff --git a/tourniquet-celery/Cargo.toml b/tourniquet-celery/Cargo.toml index faacc09..0fb814d 100644 --- a/tourniquet-celery/Cargo.toml +++ b/tourniquet-celery/Cargo.toml @@ -20,6 +20,7 @@ celery = { version = "0.5", default-features = false } log = "0.4" tourniquet = { version = "0.4", path = ".." } tracing = { version = "0.1", optional = true } +url = "2" [dev-dependencies] serde = "1.0" diff --git a/tourniquet-celery/src/lib.rs b/tourniquet-celery/src/lib.rs index 986a871..d1ce553 100644 --- a/tourniquet-celery/src/lib.rs +++ b/tourniquet-celery/src/lib.rs @@ -4,9 +4,9 @@ //! # Example //! //! ```rust,no_run -//! # use celery::task::TaskResult; +//! # use celery::{task, task::TaskResult}; //! # use tourniquet::RoundRobin; -//! # use tourniquet_celery::{CeleryConnector, RoundRobinExt}; +//! # use tourniquet_celery::{CeleryConnector, CelerySource, RoundRobinExt}; //! # //! #[celery::task] //! async fn do_work(work: String) -> TaskResult<()> { @@ -18,16 +18,17 @@ //! # #[tokio::main] //! # async fn main() -> Result<(), Box> { //! let rr = RoundRobin::new( -//! vec!["amqp://rabbit01:5672/".to_owned(), "amqp://rabbit02:5672".to_owned()], +//! vec![CelerySource::from("amqp://rabbit01:5672/".to_owned()), CelerySource::from("amqp://rabbit02:5672".to_owned())], //! CeleryConnector { name: "rr", routes: &[("*", "my_route")], ..Default::default() }, //! ); //! //! # let work = "foo".to_owned(); -//! rr.send_task(|| do_work::new(work.clone())).await.expect("Failed to send task"); +//! rr.send_task(|| do_work(work.clone())).await.expect("Failed to send task"); //! # Ok(()) //! # } //! ``` +use std::borrow::Cow; use std::error::Error; use std::fmt::{Debug, Display, Error as FmtError, Formatter}; @@ -38,6 +39,8 @@ use celery::{ task::{AsyncResult, Signature, Task}, Celery, CeleryBuilder, }; +use url::Url; + use tourniquet::{Connector, Next, RoundRobin}; #[cfg(feature = "trace")] use tracing::{ @@ -92,6 +95,37 @@ impl Error for RRCeleryError { } } +/// Wrapper for String +#[derive(Clone)] +pub struct CelerySource(String); + +impl From for CelerySource { + fn from(src: String) -> Self { + Self(src) + } +} + +fn safe_source(url: &String) -> Cow<'_, String> { + // URL that is safe to log (password stripped) + let Some(mut url_safe): Option = url.parse().ok() else { return Cow::Borrowed(url) }; + if url_safe.password().is_some() { + let _ = url_safe.set_password(Some("********")); + } + Cow::Owned(url_safe.to_string()) +} + +impl Display for CelerySource { + fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), FmtError> { + Display::fmt(&safe_source(&self.0), f) + } +} + +impl Debug for CelerySource { + fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), FmtError> { + Debug::fmt(&safe_source(&self.0), f) + } +} + /// Ready to use connector for Celery. /// /// Please refer to @@ -115,10 +149,10 @@ impl<'a> Default for CeleryConnector<'a> { } #[async_trait] -impl<'a> Connector for CeleryConnector<'a> { - #[cfg_attr(feature = "trace", tracing::instrument(skip(self), err))] - async fn connect(&self, url: &String) -> Result { - let mut builder = CeleryBuilder::new(self.name, url.as_ref()); +impl<'a> Connector for CeleryConnector<'a> { + #[cfg_attr(feature = "trace", instrument(skip(self), err))] + async fn connect(&self, src: &CelerySource) -> Result { + let mut builder = CeleryBuilder::new(self.name, src.0.as_ref()); if let Some(queue) = self.default_queue { builder = builder.default_queue(queue); @@ -145,7 +179,7 @@ pub trait RoundRobinExt { #[async_trait] impl RoundRobinExt for RoundRobin where - SvcSrc: Debug + Send + Sync, + SvcSrc: Debug + Send + Sync + Display + Clone, Conn: Connector + Send + Sync, { /// Send a Celery task. @@ -172,11 +206,25 @@ where self.run(|celery| async move { Ok(celery.send_task(task_gen()).await?) }).await?; #[cfg(feature = "trace")] - Span::current().record("task_id", &display(&task.task_id)); + Span::current().record("task_id", display(&task.task_id)); Ok(task) } } /// Shorthand type for a basic RoundRobin type using Celery -pub type CeleryRoundRobin = RoundRobin>; +pub type CeleryRoundRobin = + RoundRobin>; + +#[cfg(test)] +mod tests { + use super::CelerySource; + + #[test] + fn test_display_debug_celery_source_strips_password() { + let source = CelerySource::from("amqp://mylogin:mypassword@rabbitmq.myserver.com/product".to_owned()); + + assert_eq!(format!("{source}"),"amqp://mylogin:********@rabbitmq.myserver.com/product"); + assert_eq!(format!("{source:?}"),"\"amqp://mylogin:********@rabbitmq.myserver.com/product\""); + } +}