Skip to content

Commit

Permalink
Implement a parser for memory type and byte conversion
Browse files Browse the repository at this point in the history
Remove redundant parsing for integer types
  • Loading branch information
dormant-user committed Mar 13, 2024
1 parent 6fd4c19 commit bbcab88
Show file tree
Hide file tree
Showing 5 changed files with 168 additions and 35 deletions.
15 changes: 11 additions & 4 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,14 @@ pub async fn start() -> io::Result<()> {
println!("{}[v{}] - {}", &cargo.pkg_name, &cargo.pkg_version, &cargo.description);
squire::ascii_art::random();

// Log a warning message for max payload size beyond 1 GB
if config.max_payload_size > 1024 * 1024 * 1024 {
// Since the default is just 100 MB, the only way to get here is to have an env var
log::warn!("Max payload size is set to '{}' which exceeds the optimal upload size.",
std::env::var("max_payload_size").unwrap());
log::warn!("Please consider network bandwidth and latency, before using RuStream to upload such high-volume data.");
}

if config.secure_session {
log::warn!(
"Secure session is turned on! This means that the server can ONLY be hosted via HTTPS or localhost"
Expand All @@ -61,14 +69,13 @@ pub async fn start() -> io::Result<()> {
The closure is defining the configuration for the Actix web server.
The purpose of the closure is to configure the server before it starts listening for incoming requests.
*/
let max_payload_size = 10 * 1024 * 1024 * 1024; // 10 GB
let application = move || {
App::new() // Creates a new Actix web application
.app_data(web::Data::new(config_clone.clone()))
.app_data(web::Data::new(jinja.clone()))
.app_data(web::Data::new(fernet.clone()))
.app_data(web::Data::new(session.clone()))
.app_data(web::PayloadConfig::default().limit(max_payload_size))
.app_data(web::PayloadConfig::default().limit(config_clone.max_payload_size))
.wrap(squire::middleware::get_cors(config_clone.websites.clone()))
.wrap(middleware::Logger::default()) // Adds a default logger middleware to the application
.service(routes::basics::health) // Registers a service for handling requests
Expand All @@ -84,8 +91,8 @@ pub async fn start() -> io::Result<()> {
.service(routes::upload::save_files)
};
let server = HttpServer::new(application)
.workers(config.workers as usize)
.max_connections(config.max_connections as usize);
.workers(config.workers)
.max_connections(config.max_connections);
// Reference: https://actix.rs/docs/http2/
if config.cert_file.exists() && config.key_file.exists() {
log::info!("Binding SSL certificate to serve over HTTPS");
Expand Down
2 changes: 1 addition & 1 deletion src/routes/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ pub async fn login(request: HttpRequest,
let payload = serde_json::to_string(&mapped).unwrap();
let encrypted_payload = fernet.encrypt(payload.as_bytes());

let cookie_duration = Duration::seconds(config.session_duration as i64);
let cookie_duration = Duration::seconds(config.session_duration);
let expiration = OffsetDateTime::now_utc() + cookie_duration;
let base_cookie = Cookie::build("session_token", encrypted_payload)
.http_only(true)
Expand Down
4 changes: 2 additions & 2 deletions src/squire/authenticator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,12 +167,12 @@ pub fn verify_token(
username,
};
}
if current_time - timestamp > config.session_duration as i64 {
if current_time - timestamp > config.session_duration {
return AuthToken { ok: false, detail: "Session Expired".to_string(), username };
}
AuthToken {
ok: true,
detail: format!("Session valid for {}s", timestamp + config.session_duration as i64 - current_time),
detail: format!("Session valid for {}s", timestamp + config.session_duration - current_time),
username,
}
} else {
Expand Down
39 changes: 22 additions & 17 deletions src/squire/settings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,18 @@ pub struct Config {
/// Host IP address for media streaming.
pub media_host: String,
/// Port number for hosting the application.
pub media_port: i32,
pub media_port: u16,
/// Duration of a session in seconds.
pub session_duration: i32,
pub session_duration: i64,
/// List of supported file formats.
pub file_formats: Vec<String>,

/// Number of worker threads to spin up the server.
pub workers: i32,
pub workers: usize,
/// Maximum number of concurrent connections.
pub max_connections: i32,
pub max_connections: usize,
/// Max payload allowed by the server in request body.
pub max_payload_size: usize,
/// List of websites (supports regex) to add to CORS configuration.
pub websites: Vec<String>,

Expand All @@ -38,13 +40,13 @@ pub struct Config {
pub cert_file: path::PathBuf,
}

/// Returns the default value for debug flag
/// Returns the default value for debug flag.
pub fn default_debug() -> bool { false }

/// Returns the default value for utc_logging
/// Returns the default value for UTC logging.
pub fn default_utc_logging() -> bool { true }

/// Returns the default value for ssl files
/// Returns the default value for SSL files.
pub fn default_ssl() -> path::PathBuf { path::PathBuf::new() }

/// Returns the default media host based on the local machine's IP address.
Expand All @@ -63,33 +65,36 @@ pub fn default_media_host() -> String {
"localhost".to_string()
}

/// Returns the default media port (8000).
pub fn default_media_port() -> i32 { 8000 }
/// Returns the default media port (8000)
pub fn default_media_port() -> u16 { 8000 }

/// Returns the default session duration (3600 seconds).
pub fn default_session_duration() -> i32 { 3600 }
/// Returns the default session duration (3600 seconds)
pub fn default_session_duration() -> i64 { 3600 }

/// Returns the file formats supported by default.
pub fn default_file_formats() -> Vec<String> {
vec!["mp4".to_string(), "mov".to_string(), "jpg".to_string(), "jpeg".to_string()]
}

/// Returns the default number of worker threads (half of logical cores).
pub fn default_workers() -> i32 {
/// Returns the default number of worker threads (half of logical cores)
pub fn default_workers() -> usize {
let logical_cores = thread::available_parallelism();
match logical_cores {
Ok(cores) => cores.get() as i32 / 2,
Ok(cores) => cores.get() / 2,
Err(err) => {
log::error!("{}", err);
3
}
}
}

/// Returns the default maximum number of concurrent connections (3).
pub fn default_max_connections() -> i32 { 3 }
/// Returns the default maximum number of concurrent connections (3)
pub fn default_max_connections() -> usize { 3 }

/// Returns an empty list as the default website (CORS configuration).
/// Returns the default max payload size (100 MB)
pub fn default_max_payload_size() -> usize { 100 * 1024 * 1024 }

/// Returns an empty list as the default website (CORS configuration)
pub fn default_websites() -> Vec<String> { Vec::new() }

/// Returns the default value for secure_session
Expand Down
143 changes: 132 additions & 11 deletions src/squire/startup.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use std;
use std::ffi::OsStr;
use std::io::Write;

use chrono::{DateTime, Local, Utc};
Expand Down Expand Up @@ -110,25 +109,75 @@ fn parse_bool(key: &str) -> Option<bool> {
}
}

/// Extracts the env var by key and parses it as a `i32`
/// Extracts the env var by key and parses it as a `i64`
///
/// # Arguments
///
/// * `key` - Key for the environment variable.
///
/// # Returns
///
/// Returns an `Option<i32>` if the value is available.
/// Returns an `Option<i64>` if the value is available.
///
/// # Panics
///
/// If the value is present, but it is an invalid data-type.
fn parse_i32(key: &str) -> Option<i32> {
fn parse_i64(key: &str) -> Option<i64> {
match std::env::var(key) {
Ok(val) => match val.parse() {
Ok(parsed) => Some(parsed),
Err(_) => {
panic!("\n{}\n\texpected i32, received '{}' [value=invalid]\n", key, val);
panic!("\n{}\n\texpected i64, received '{}' [value=invalid]\n", key, val);
}
},
Err(_) => None,
}
}

/// Extracts the env var by key and parses it as a `u16`
///
/// # Arguments
///
/// * `key` - Key for the environment variable.
///
/// # Returns
///
/// Returns an `Option<u16>` if the value is available.
///
/// # Panics
///
/// If the value is present, but it is an invalid data-type.
fn parse_u16(key: &str) -> Option<u16> {
match std::env::var(key) {
Ok(val) => match val.parse() {
Ok(parsed) => Some(parsed),
Err(_) => {
panic!("\n{}\n\texpected u16, received '{}' [value=invalid]\n", key, val);
}
},
Err(_) => None,
}
}

/// Extracts the env var by key and parses it as a `usize`
///
/// # Arguments
///
/// * `key` - Key for the environment variable.
///
/// # Returns
///
/// Returns an `Option<usize>` if the value is available.
///
/// # Panics
///
/// If the value is present, but it is an invalid data-type.
fn parse_usize(key: &str) -> Option<usize> {
match std::env::var(key) {
Ok(val) => match val.parse() {
Ok(parsed) => Some(parsed),
Err(_) => {
panic!("\n{}\n\texpected usize, received '{}' [value=invalid]\n", key, val);
}
},
Err(_) => None,
Expand Down Expand Up @@ -180,6 +229,76 @@ fn parse_path(key: &str) -> Option<std::path::PathBuf> {
}
}

/// Parses the maximum payload size from human-readable memory format to bytes.
///
/// - `key` - Key for the environment variable.
///
/// ## See Also
///
/// - This function handles internal panic gracefully, in the most detailed way possible.
/// - Panic outputs are suppressed with a custom hook.
/// - Custom hook is set before wrapping the potentially panicking function inside `catch_unwind`.
/// - Custom hook is reset later, so the future panics and go uncaught.
/// - Error message from panic payload is also further processed, to get a detailed reason for panic.
///
/// # Returns
///
/// Returns an option of usize if the value is parsable and within the allowed size limit.
fn parse_max_payload(key: &str) -> Option<usize> {
match std::env::var(key) {
Ok(value) => {

let custom_hook = std::panic::take_hook();
std::panic::set_hook(Box::new(|_panic_info| {}));
let result = std::panic::catch_unwind(|| parse_memory(&value));
std::panic::set_hook(custom_hook);

match result {
Ok(output) => {
if let Some(value) = output {
Some(value)
} else {
panic!("\n{}\n\texpected format: '100 MB', received '{}' [value=invalid]\n",
key, value);
}
}
Err(panic_payload) => {
if let Some(&error) = panic_payload.downcast_ref::<&str>() {
panic!("\n{}\n\t{} [value=invalid]\n", key, error);
} else if let Some(error) = panic_payload.downcast_ref::<String>() {
panic!("\n{}\n\t{} [value=invalid]\n", key, error);
} else if let Some(error) = panic_payload.downcast_ref::<Box<dyn std::fmt::Debug + Send + 'static>>() {
panic!("\n{}\n\t{:?} [value=invalid]\n", key, error);
} else {
panic!("\n{}\n\tinvalid memory format! unable to parse panic payload [value=invalid]\n", key);
}
}
}
}
Err(_) => {
None
}
}
}

fn parse_memory(memory: &str) -> Option<usize> {
let value = memory.trim();
let (size_str, unit) = value.split_at(value.len() - 2);
let size: usize = match size_str.strip_suffix(' ').unwrap_or_default().parse() {
Ok(num) => num,
Err(_) => return None,
};

match unit.to_lowercase().as_str() {
"zb" => Some(size * 1024 * 1024 * 1024 * 1024 * 1024),
"tb" => Some(size * 1024 * 1024 * 1024 * 1024),
"gb" => Some(size * 1024 * 1024 * 1024),
"mb" => Some(size * 1024 * 1024),
"kb" => Some(size * 1024),
_ => None,
}
}

/// Handler that's responsible to parse all the env vars.
///
/// # Returns
Expand All @@ -190,15 +309,16 @@ fn load_env_vars() -> settings::Config {
let debug = parse_bool("debug").unwrap_or(settings::default_debug());
let utc_logging = parse_bool("utc_logging").unwrap_or(settings::default_utc_logging());
let media_host = std::env::var("media_host").unwrap_or(settings::default_media_host());
let media_port = parse_i32("media_port").unwrap_or(settings::default_media_port());
let session_duration = parse_i32("session_duration").unwrap_or(settings::default_session_duration());
let media_port = parse_u16("media_port").unwrap_or(settings::default_media_port());
let session_duration = parse_i64("session_duration").unwrap_or(settings::default_session_duration());
let file_formats = parse_vec("file_formats").unwrap_or(settings::default_file_formats());
let workers = parse_i32("workers").unwrap_or(settings::default_workers());
let max_connections = parse_i32("max_connections").unwrap_or(settings::default_max_connections());
let workers = parse_usize("workers").unwrap_or(settings::default_workers());
let max_connections = parse_usize("max_connections").unwrap_or(settings::default_max_connections());
let websites = parse_vec("websites").unwrap_or(settings::default_websites());
let secure_session = parse_bool("secure_session").unwrap_or(settings::default_secure_session());
let key_file = parse_path("key_file").unwrap_or(settings::default_ssl());
let cert_file = parse_path("cert_file").unwrap_or(settings::default_ssl());
let max_payload_size = parse_max_payload("max_payload_size").unwrap_or(settings::default_max_payload_size());
settings::Config {
authorization,
media_source,
Expand All @@ -210,6 +330,7 @@ fn load_env_vars() -> settings::Config {
file_formats,
workers,
max_connections,
max_payload_size,
websites,
secure_session,
key_file,
Expand Down Expand Up @@ -253,7 +374,7 @@ fn validate_dir_structure(config: &settings::Config, cargo: &Cargo) {
let secure_dir = index_vec.last().unwrap();
// secure_parent_path is the secure index's location
let secure_parent_path = &index_vec[0..index_vec.len() - 1]
.join(OsStr::new(std::path::MAIN_SEPARATOR_STR));
.join(std::ffi::OsStr::new(std::path::MAIN_SEPARATOR_STR));
errors.push_str(&format!(
"\n{:?}\n\tSecure index directory [{:?}] should be at the root [{:?}] [depth={}, valid=1]\n\
\t> Hint: Either move {:?} within {:?}, [OR] set the 'media_source' to {:?}\n",
Expand Down Expand Up @@ -284,7 +405,7 @@ fn validate_dir_structure(config: &settings::Config, cargo: &Cargo) {
get_time(config.utc_logging), cargo.crate_name,
&secure_path.to_str().unwrap())
}
},
}
Err(err) => panic!("{}", err)
}
}
Expand Down

0 comments on commit bbcab88

Please sign in to comment.