diff --git a/bottlecap/src/bin/bottlecap/main.rs b/bottlecap/src/bin/bottlecap/main.rs index 5c37a17a..561c13e7 100644 --- a/bottlecap/src/bin/bottlecap/main.rs +++ b/bottlecap/src/bin/bottlecap/main.rs @@ -50,7 +50,6 @@ use dogstatsd::{ dogstatsd::{DogStatsD, DogStatsDConfig}, flusher::{build_fqdn_metrics, Flusher as MetricsFlusher}, }; -use lazy_static::lazy_static; use reqwest::Client; use serde::Deserialize; use std::{ @@ -70,11 +69,6 @@ use tokio_util::sync::CancellationToken; use tracing::{debug, error}; use tracing_subscriber::EnvFilter; -lazy_static! { - static ref API_KEY_REGEX: regex::Regex = - regex::Regex::new(r"^[a-f0-9]{32}$").expect("Invalid regex for DD API KEY"); -} - #[derive(Clone, Deserialize)] #[serde(rename_all = "camelCase")] struct RegisterResponse { @@ -181,9 +175,7 @@ async fn main() -> Result<()> { .await .map_err(|e| Error::new(std::io::ErrorKind::InvalidData, e.to_string()))?; - if let Some(resolved_api_key) = - clean_api_key(resolve_secrets(Arc::clone(&config), &aws_config).await) - { + if let Some(resolved_api_key) = resolve_secrets(Arc::clone(&config), &aws_config).await { match extension_loop_active(&aws_config, &config, &client, &r, resolved_api_key).await { Ok(()) => { debug!("Extension loop completed successfully"); @@ -202,17 +194,6 @@ async fn main() -> Result<()> { } } -fn clean_api_key(maybe_key: Option) -> Option { - if let Some(key) = maybe_key { - let clean_key = key.trim_end_matches('\n').replace(' ', "").to_string(); - if API_KEY_REGEX.is_match(&clean_key) { - return Some(clean_key); - } - error!("API key has invalid format"); - } - None -} - fn load_configs() -> (AwsConfig, Arc) { // First load the configuration let aws_config = AwsConfig { diff --git a/bottlecap/src/secrets/decrypt.rs b/bottlecap/src/secrets/decrypt.rs index 0a39b0fd..25bf9292 100644 --- a/bottlecap/src/secrets/decrypt.rs +++ b/bottlecap/src/secrets/decrypt.rs @@ -13,38 +13,49 @@ use tracing::debug; use tracing::error; pub async fn resolve_secrets(config: Arc, aws_config: &AwsConfig) -> Option { - if !config.api_key.is_empty() { - debug!("DD_API_KEY found, not trying to resolve secrets"); - Some(config.api_key.clone()) - } else if !config.api_key_secret_arn.is_empty() || !config.kms_api_key.is_empty() { - let before_decrypt = Instant::now(); - - let client = match Client::builder().use_rustls_tls().build() { - Ok(client) => client, - Err(err) => { - error!("Error creating reqwest client: {}", err); - return None; + let api_key_candidate = + if !config.api_key_secret_arn.is_empty() || !config.kms_api_key.is_empty() { + let before_decrypt = Instant::now(); + + let client = match Client::builder().use_rustls_tls().build() { + Ok(client) => client, + Err(err) => { + error!("Error creating reqwest client: {}", err); + return None; + } + }; + + let decrypted_key = if config.kms_api_key.is_empty() { + decrypt_aws_sm(&client, config.api_key_secret_arn.clone(), aws_config).await + } else { + decrypt_aws_kms(&client, config.kms_api_key.clone(), aws_config).await + }; + + debug!("Decrypt took {}ms", before_decrypt.elapsed().as_millis()); + + match decrypted_key { + Ok(key) => Some(key), + Err(err) => { + error!("Error decrypting key: {}", err); + None + } } - }; - - let decrypted_key = if config.api_key_secret_arn.is_empty() { - decrypt_aws_kms(&client, config.kms_api_key.clone(), aws_config).await } else { - decrypt_aws_sm(&client, config.api_key_secret_arn.clone(), aws_config).await + Some(config.api_key.clone()) }; - debug!("Decrypt took {}ms", before_decrypt.elapsed().as_millis()); + clean_api_key(api_key_candidate) +} - match decrypted_key { - Ok(key) => Some(key), - Err(err) => { - error!("Error decrypting key: {}", err); - None - } +fn clean_api_key(maybe_key: Option) -> Option { + if let Some(key) = maybe_key { + let clean_key = key.trim_end_matches('\n').replace(' ', "").to_string(); + if !clean_key.is_empty() { + return Some(clean_key); } - } else { - return None; + error!("API key has invalid format"); } + None } struct RequestArgs<'a> { @@ -253,6 +264,14 @@ mod tests { use super::*; use chrono::{NaiveDateTime, TimeZone}; + #[test] + fn key_cleanup() { + let key = clean_api_key(Some(" 32alxcxf\n".to_string())); + assert_eq!(key.expect("it should parse the key"), "32alxcxf"); + let key = clean_api_key(Some(" \n".to_string())); + assert_eq!(key, None); + } + #[test] #[allow(clippy::unwrap_used)] fn test_build_get_secret_signed_headers() {