diff --git a/crates/local-cluster-runner/src/node/mod.rs b/crates/local-cluster-runner/src/node/mod.rs index aaf4f9ef9..4debc6640 100644 --- a/crates/local-cluster-runner/src/node/mod.rs +++ b/crates/local-cluster-runner/src/node/mod.rs @@ -47,7 +47,7 @@ pub struct Node { #[mutator(requires = [base_dir])] pub fn with_node_socket(self) { let node_socket: PathBuf = PathBuf::from(self.base_config.node_name()).join("node.sock"); - self.base_config.common.bind_address = BindAddress::Uds(node_socket.clone()); + self.base_config.common.bind_address = Some(BindAddress::Uds(node_socket.clone())); self.base_config.common.advertised_address = AdvertisedAddress::Uds(node_socket); } @@ -207,8 +207,19 @@ impl Node { if let BindAddress::Uds(file) = &mut self.base_config.metadata_store.bind_address { *file = base_dir.join(&*file) } - if let BindAddress::Uds(file) = &mut self.base_config.common.bind_address { - *file = base_dir.join(&*file) + + if self.base_config.common.bind_address.is_none() { + // Derive bind_address from advertised_address + self.base_config.common.bind_address = Some( + self.base_config + .common + .advertised_address + .derive_bind_address(), + ); + } + + if let Some(BindAddress::Uds(file)) = &mut self.base_config.common.bind_address { + *file = base_dir.join(&*file); } if let AdvertisedAddress::Uds(file) = &mut self.base_config.common.advertised_address { *file = base_dir.join(&*file) diff --git a/crates/node/src/network_server/service.rs b/crates/node/src/network_server/service.rs index ef55cf536..5ccd9bdd9 100644 --- a/crates/node/src/network_server/service.rs +++ b/crates/node/src/network_server/service.rs @@ -63,7 +63,7 @@ impl NetworkServer { ); server_builder - .run(node_health, axum_router, &options.bind_address) + .run(node_health, axum_router, &options.bind_address.unwrap()) .await?; Ok(()) diff --git a/crates/types/src/config/common.rs b/crates/types/src/config/common.rs index 646aa5423..221fb4746 100644 --- a/crates/types/src/config/common.rs +++ b/crates/types/src/config/common.rs @@ -65,9 +65,11 @@ pub struct CommonOptions { #[serde(flatten)] pub metadata_store_client: MetadataStoreClientOptions, - /// Address to bind for the Node server. Default is `0.0.0.0:5122` + /// Address to bind for the Node server. Derived from the advertised address, defaulting + /// to `0.0.0.0:$PORT` (where the port will be inferred from the URL scheme). + #[serde(default, skip_serializing_if = "Option::is_none")] #[cfg_attr(feature = "schemars", schemars(with = "String"))] - pub bind_address: BindAddress, + pub bind_address: Option, /// Address that other nodes will use to connect to this node. Default is `http://127.0.0.1:5122/` #[cfg_attr(feature = "schemars", schemars(with = "String"))] @@ -316,6 +318,14 @@ impl CommonOptions { .expect("number of cpu cores fits in u32"), ) } + + /// set derived values if they are not configured to reduce verbose configurations + pub fn set_derived_values(&mut self) { + // Only derive bind_address if it is not explicitly set + if self.bind_address.is_none() { + self.bind_address = Some(self.advertised_address.derive_bind_address()); + } + } } impl Default for CommonOptions { @@ -337,7 +347,7 @@ impl Default for CommonOptions { allow_bootstrap: true, base_dir: None, metadata_store_client: MetadataStoreClientOptions::default(), - bind_address: "0.0.0.0:5122".parse().unwrap(), + bind_address: None, advertised_address: AdvertisedAddress::from_str("http://127.0.0.1:5122/").unwrap(), bootstrap_num_partitions: NonZeroU16::new(24).unwrap(), histogram_inactivity_timeout: None, diff --git a/crates/types/src/config_loader.rs b/crates/types/src/config_loader.rs index 7dff6d471..f063c8cfe 100644 --- a/crates/types/src/config_loader.rs +++ b/crates/types/src/config_loader.rs @@ -60,7 +60,9 @@ impl ConfigLoader { figment = figment.merge(Figment::from(Serialized::defaults(cli_overrides))) } - let config: Configuration = figment.extract()?; + let mut config: Configuration = figment.extract()?; + + config.common.set_derived_values(); Ok(config.apply_cascading_values()) } diff --git a/crates/types/src/net/mod.rs b/crates/types/src/net/mod.rs index 65cde116a..9a4e1c9a6 100644 --- a/crates/types/src/net/mod.rs +++ b/crates/types/src/net/mod.rs @@ -18,10 +18,11 @@ pub mod partition_processor_manager; pub mod remote_query_scanner; pub mod replicated_loglet; +use anyhow::{Context, Error}; // re-exports for convenience pub use error::*; -use std::net::{AddrParseError, SocketAddr}; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; use std::path::PathBuf; use std::str::FromStr; @@ -48,22 +49,61 @@ pub enum AdvertisedAddress { /// Unix domain socket #[display("unix:{}", _0.display())] Uds(PathBuf), - /// Hostname or host:port pair, or any unrecognizable string. + /// Hostname or host:port pair #[display("{}", _0)] Http(Uri), } impl FromStr for AdvertisedAddress { - type Err = http::uri::InvalidUri; + type Err = Error; fn from_str(s: &str) -> Result { - if let Some(stripped_address) = s.strip_prefix("unix:") { - Ok(AdvertisedAddress::Uds( - stripped_address.parse().expect("infallible"), - )) - } else { - // try to parse as a URI - Ok(AdvertisedAddress::Http(s.parse()?)) + match s.trim() { + "" => Err(anyhow::anyhow!("Advertised address cannot be empty")), + address if address.starts_with("unix:") => { + parse_uds(&address[5..]).map(AdvertisedAddress::Uds) + } + address => parse_http(address), + } + } +} + +fn parse_uds(s: &str) -> Result { + s.parse::() + .with_context(|| format!("Failed to parse Unix domain socket path: '{}'", s)) +} + +fn parse_http(s: &str) -> Result { + let uri = s + .parse::() + .with_context(|| format!("Invalid URI format: '{}'", s))?; + + match uri.scheme_str() { + Some("http") | Some("https") => Ok(AdvertisedAddress::Http(uri)), + Some(other) => Err(anyhow::anyhow!("Unsupported URI scheme '{}'", other)), + None => Err(anyhow::anyhow!("Missing URI scheme in: '{}'", s)), + } +} + +impl AdvertisedAddress { + /// Derives a `BindAddress` based on the advertised address + pub fn derive_bind_address(&self) -> BindAddress { + match self { + AdvertisedAddress::Http(uri) => { + let port = uri + .authority() + .and_then(|auth| auth.port_u16()) + .unwrap_or(80); // HTTP default port is 80 if unspecified + + let ip = if uri.host().unwrap_or("").contains(':') { + IpAddr::V6(Ipv6Addr::UNSPECIFIED) + } else { + IpAddr::V4(Ipv4Addr::UNSPECIFIED) + }; + + BindAddress::Socket(SocketAddr::new(ip, port)) + } + AdvertisedAddress::Uds(path) => BindAddress::Uds(path.clone()), } } } @@ -82,26 +122,24 @@ pub enum BindAddress { /// Unix domain socket #[display("unix:{}", _0.display())] Uds(PathBuf), - /// Socket addr. + /// Socket address (IP and port). #[display("{}", _0)] Socket(SocketAddr), } impl FromStr for BindAddress { - type Err = AddrParseError; + type Err = Error; fn from_str(s: &str) -> Result { - if let Some(stripped_address) = s.strip_prefix("unix:") { - Ok(BindAddress::Uds( - stripped_address.parse().expect("infallible"), - )) - } else { - // try to parse as a URI - Ok(BindAddress::Socket(s.parse()?)) + match s.strip_prefix("unix:") { + Some(path) => parse_uds(path).map(BindAddress::Uds), + None => s + .parse::() + .map(BindAddress::Socket) + .map_err(|e| Error::new(e).context("Failed to parse socket address")), } } } - pub trait RpcRequest: Targeted { type ResponseMessage: Targeted + WireEncode; } @@ -198,31 +236,196 @@ use {define_message, define_rpc}; #[cfg(test)] mod tests { - use http::Uri; - use super::*; - // test parsing [`AdvertisedAddress`] + use std::{net::Ipv6Addr, str::FromStr}; + + #[test] + fn test_parse_empty_input() { + let result = AdvertisedAddress::from_str(""); + assert!(result.is_err()); + assert_eq!( + result.unwrap_err().to_string(), + "Advertised address cannot be empty" + ); + } + + #[test] + fn test_parse_valid_uds() { + let input = "unix:/path/to/socket"; + let addr = input + .parse::() + .expect("Failed to parse UDS address"); + + match addr { + AdvertisedAddress::Uds(path) => assert_eq!(path, PathBuf::from("/path/to/socket")), + _ => panic!("Expected Uds variant"), + } + } + + #[test] + fn test_parse_valid_http_uri() { + let result = AdvertisedAddress::from_str("http://localhost:8080"); + assert!(result.is_ok()); + if let AdvertisedAddress::Http(uri) = result.unwrap() { + assert_eq!(uri.to_string(), "http://localhost:8080/"); + } else { + panic!("Expected Http variant"); + } + } + #[test] - fn test_parse_network_address() -> anyhow::Result<()> { - let tcp: AdvertisedAddress = "127.0.0.1:5123".parse()?; - restate_test_util::assert_eq!(tcp, AdvertisedAddress::Http("127.0.0.1:5123".parse()?)); - - let tcp: AdvertisedAddress = "unix:/tmp/unix.socket".parse()?; - restate_test_util::assert_eq!( - tcp, - AdvertisedAddress::Uds("/tmp/unix.socket".parse().unwrap()) + fn test_parse_missing_scheme() { + let result = AdvertisedAddress::from_str("localhost:8080"); + assert!(result.is_err()); + assert_eq!( + result.unwrap_err().to_string(), + "Missing URI scheme in: 'localhost:8080'" ); + } + + #[test] + fn test_parse_unsupported_scheme() { + let result = AdvertisedAddress::from_str("ftp://localhost"); + assert!(result.is_err()); + assert!(result + .unwrap_err() + .to_string() + .contains("Unsupported URI scheme 'ftp'")); + } - let tcp: AdvertisedAddress = "localhost:5123".parse()?; - restate_test_util::assert_eq!(tcp, AdvertisedAddress::Http("localhost:5123".parse()?)); + #[test] + fn test_parse_invalid_address() { + let input = ""; + let result = input.parse::(); + assert!(result.is_err(), "Expected an error for empty input"); - let tcp: AdvertisedAddress = "https://localhost:5123".parse()?; - restate_test_util::assert_eq!( - tcp, - AdvertisedAddress::Http(Uri::from_static("https://localhost:5123")) + let input = "ftp://localhost:8080"; + let result = input.parse::(); + assert!( + result.is_err(), + "Expected an error for unsupported URI scheme" ); + } + + #[test] + fn test_derive_bind_address_http() { + let input = "http://localhost:8080"; + let addr = input + .parse::() + .expect("Failed to parse HTTP address"); + let bind_addr = addr.derive_bind_address(); + + match bind_addr { + BindAddress::Socket(socket_addr) => { + assert_eq!( + socket_addr, + SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 8080) + ) + } + _ => panic!("Expected Socket variant"), + } + } + + #[test] + fn test_derive_bind_address_fallback_port() { + // Case with no port specified, should fallback to 80 + let advertised_address = AdvertisedAddress::from_str("http://example.com").unwrap(); + let bind_address = advertised_address.derive_bind_address(); + + match bind_address { + BindAddress::Socket(socket_addr) => { + assert_eq!(socket_addr.port(), 80, "Expected port 80 for fallback"); + assert_eq!( + socket_addr.ip(), + IpAddr::V4(Ipv4Addr::UNSPECIFIED), + "Expected IPv4 unspecified address" + ); + } + _ => panic!("Expected BindAddress::Socket"), + } + } + + #[test] + fn test_derive_bind_address_uds() { + let input = "unix:/path/to/socket"; + let addr = input + .parse::() + .expect("Failed to parse UDS address"); + let bind_addr = addr.derive_bind_address(); + + match bind_addr { + BindAddress::Uds(path) => assert_eq!(path, PathBuf::from("/path/to/socket")), + _ => panic!("Expected Uds variant"), + } + } + #[test] + fn test_derive_bind_address_ipv6() { + // Create an IPv6 advertised address + let address = "http://[::1]:8080"; // IPv6 localhost with port 8080 + let advertised_address = + AdvertisedAddress::from_str(address).expect("Failed to parse IPv6 URI"); + + // Derive the bind address + let bind_address = advertised_address.derive_bind_address(); + + // Check that it matches the expected IPv6 bind address + match bind_address { + BindAddress::Socket(socket_addr) => { + assert_eq!(socket_addr.ip(), IpAddr::V6(Ipv6Addr::UNSPECIFIED)); + assert_eq!(socket_addr.port(), 8080); + } + _ => panic!( + "Expected BindAddress::Socket with IPv6, got {:?}", + bind_address + ), + } + } + + #[test] + fn test_parse_bind_address_socket() { + let input = "127.0.0.1:8080"; + let addr = input + .parse::() + .expect("Failed to parse Socket address"); + + match addr { + BindAddress::Socket(socket_addr) => { + assert_eq!(socket_addr, "127.0.0.1:8080".parse().unwrap()) + } + _ => panic!("Expected Socket variant"), + } + } + + #[test] + fn test_parse_bind_address_uds() { + let input = "unix:/path/to/socket"; + let addr = input + .parse::() + .expect("Failed to parse UDS address"); + + match addr { + BindAddress::Uds(path) => assert_eq!(path, PathBuf::from("/path/to/socket")), + _ => panic!("Expected Uds variant"), + } + } + + #[test] + fn test_parse_bind_address_invalid() { + let input = "unsupported:address"; + let result = input.parse::(); + assert!( + result.is_err(), + "Expected an error for invalid bind address" + ); + } + + #[test] + fn test_invalid_advertised_address() { + // Test case for an invalid AdvertisedAddress string + let result = AdvertisedAddress::from_str("invalid-address"); - Ok(()) + // Parsing should fail, resulting in an error + assert!(result.is_err(), "Expected an error for invalid address"); } } diff --git a/server/tests/common/replicated_loglet.rs b/server/tests/common/replicated_loglet.rs index 92fcea734..c1aece5f2 100644 --- a/server/tests/common/replicated_loglet.rs +++ b/server/tests/common/replicated_loglet.rs @@ -51,7 +51,7 @@ async fn replicated_loglet_client( .common .set_cluster_name(cluster.cluster_name().to_owned()); config.common.advertised_address = AdvertisedAddress::Uds(node_socket.clone()); - config.common.bind_address = BindAddress::Uds(node_socket.clone()); + config.common.bind_address = Some(BindAddress::Uds(node_socket.clone())); config.common.metadata_store_client = cluster.nodes[0] .config() .common