diff --git a/feature-flags/src/lib.rs b/feature-flags/src/lib.rs index 71a5e69..c9d07cc 100644 --- a/feature-flags/src/lib.rs +++ b/feature-flags/src/lib.rs @@ -6,3 +6,11 @@ pub mod server; pub mod v0_endpoint; pub mod v0_request; pub mod team; + +// Test modules don't need to be compiled with main binary +// #[cfg(test)] +// TODO: To use in integration tests, we need to compile with binary +// or make it a separate feature using cfg(feature = "integration-tests") +// and then use this feature only in tests. +// For now, ok to just include in binary +pub mod test_utils; diff --git a/feature-flags/src/team.rs b/feature-flags/src/team.rs index cfa54c3..d55aa93 100644 --- a/feature-flags/src/team.rs +++ b/feature-flags/src/team.rs @@ -11,7 +11,7 @@ use tracing::instrument; // It's from here: https://docs.djangoproject.com/en/4.2/topics/cache/#cache-versioning // F&!£%% on the bright side we don't use this functionality yet. // Will rely on integration tests to catch this. -const TEAM_TOKEN_CACHE_PREFIX: &str = "posthog:1:team_token:"; +pub const TEAM_TOKEN_CACHE_PREFIX: &str = "posthog:1:team_token:"; // TODO: Check what happens if json has extra stuff, does serde ignore it? Yes // Make sure we don't serialize and store team data in redis. Let main decide endpoint control this... @@ -63,48 +63,13 @@ impl Team { #[cfg(test)] mod tests { - use std::sync::Arc; - use anyhow::Error; - - use crate::redis::RedisClient; - use rand::{distributions::Alphanumeric, Rng}; - + use crate::test_utils::{insert_new_team_in_redis, setup_redis_client}; use super::*; - fn random_string(prefix: &str, length: usize) -> String { - let suffix: String = rand::thread_rng() - .sample_iter(Alphanumeric) - .take(length) - .map(char::from) - .collect(); - format!("{}{}", prefix, suffix) - } - - async fn insert_new_team_in_redis(client: Arc) -> Result { - let id = rand::thread_rng().gen_range(0..10_000_000); - let token = random_string("phc_", 12); - let team = Team { - id: id, - name: "team".to_string(), - api_token: token, - }; - - let serialized_team = serde_json::to_string(&team)?; - client - .set( - format!("{TEAM_TOKEN_CACHE_PREFIX}{}", team.api_token.clone()), - serialized_team, - ) - .await?; - - Ok(team) - } #[tokio::test] async fn test_fetch_team_from_redis() { - let client = RedisClient::new("redis://localhost:6379/".to_string()) - .expect("Failed to create redis client"); - let client = Arc::new(client); + let client = setup_redis_client(None); let team = insert_new_team_in_redis(client.clone()).await.unwrap(); @@ -121,10 +86,12 @@ mod tests { #[tokio::test] async fn test_fetch_invalid_team_from_redis() { - let client = RedisClient::new("redis://localhost:6379/".to_string()) - .expect("Failed to create redis client"); - let client = Arc::new(client); + let client = setup_redis_client(None); + // TODO: It's not ideal that this can fail on random errors like connection refused. + // Is there a way to be more specific throughout this code? + // Or maybe I shouldn't be mapping conn refused to token validation error, and instead handling it as a + // top level 500 error instead of 400 right now. match Team::from_redis(client.clone(), "banana".to_string()).await { Err(FlagError::TokenValidationError) => (), _ => panic!("Expected TokenValidationError"), diff --git a/feature-flags/src/test_utils.rs b/feature-flags/src/test_utils.rs new file mode 100644 index 0000000..1a91c8b --- /dev/null +++ b/feature-flags/src/test_utils.rs @@ -0,0 +1,43 @@ +use std::sync::Arc; +use anyhow::Error; + +use crate::{redis::{Client, RedisClient}, team::{self, Team}}; +use rand::{distributions::Alphanumeric, Rng}; + +pub fn random_string(prefix: &str, length: usize) -> String { + let suffix: String = rand::thread_rng() + .sample_iter(Alphanumeric) + .take(length) + .map(char::from) + .collect(); + format!("{}{}", prefix, suffix) +} + +pub async fn insert_new_team_in_redis(client: Arc) -> Result { + let id = rand::thread_rng().gen_range(0..10_000_000); + let token = random_string("phc_", 12); + let team = Team { + id: id, + name: "team".to_string(), + api_token: token, + }; + + let serialized_team = serde_json::to_string(&team)?; + client + .set( + format!("{}{}", team::TEAM_TOKEN_CACHE_PREFIX, team.api_token.clone()), + serialized_team, + ) + .await?; + + Ok(team) +} + +pub fn setup_redis_client(url: Option) -> Arc { + let redis_url = match url { + Some(value) => value, + None => "redis://localhost:6379/".to_string(), + }; + let client = RedisClient::new(redis_url).expect("Failed to create redis client"); + Arc::new(client) +} \ No newline at end of file diff --git a/feature-flags/tests/common.rs b/feature-flags/tests/common/mod.rs similarity index 77% rename from feature-flags/tests/common.rs rename to feature-flags/tests/common/mod.rs index f66a11f..5a63285 100644 --- a/feature-flags/tests/common.rs +++ b/feature-flags/tests/common/mod.rs @@ -4,8 +4,7 @@ use std::string::ToString; use std::sync::Arc; use once_cell::sync::Lazy; -use rand::distributions::Alphanumeric; -use rand::Rng; +use reqwest::header::CONTENT_TYPE; use tokio::net::TcpListener; use tokio::sync::Notify; @@ -44,6 +43,18 @@ impl ServerHandle { client .post(format!("http://{:?}/flags", self.addr)) .body(body) + .header(CONTENT_TYPE, "application/json") + .send() + .await + .expect("failed to send request") + } + + pub async fn send_invalid_header_for_flags_request>(&self, body: T) -> reqwest::Response { + let client = reqwest::Client::new(); + client + .post(format!("http://{:?}/flags", self.addr)) + .body(body) + .header(CONTENT_TYPE, "xyz") .send() .await .expect("failed to send request") @@ -55,12 +66,3 @@ impl Drop for ServerHandle { self.shutdown.notify_one() } } - -pub fn random_string(prefix: &str, length: usize) -> String { - let suffix: String = rand::thread_rng() - .sample_iter(Alphanumeric) - .take(length) - .map(char::from) - .collect(); - format!("{}_{}", prefix, suffix) -} diff --git a/feature-flags/tests/test_flags.rs b/feature-flags/tests/test_flags.rs index 82f41f0..5302ea9 100644 --- a/feature-flags/tests/test_flags.rs +++ b/feature-flags/tests/test_flags.rs @@ -5,14 +5,20 @@ use reqwest::StatusCode; use serde_json::{json, Value}; use crate::common::*; -mod common; + +use feature_flags::test_utils::{insert_new_team_in_redis, setup_redis_client}; + +pub mod common; #[tokio::test] async fn it_sends_flag_request() -> Result<()> { - let token = random_string("token", 16); + let config = DEFAULT_CONFIG.clone(); + let distinct_id = "user_distinct_id".to_string(); - let config = DEFAULT_CONFIG.clone(); + let client = setup_redis_client(Some(config.redis_url.clone())); + let team = insert_new_team_in_redis(client.clone()).await.unwrap(); + let token = team.api_token; let server = ServerHandle::for_config(config).await; @@ -41,3 +47,33 @@ async fn it_sends_flag_request() -> Result<()> { Ok(()) } + + +#[tokio::test] +async fn it_rejects_invalid_headers_flag_request() -> Result<()> { + let config = DEFAULT_CONFIG.clone(); + + let distinct_id = "user_distinct_id".to_string(); + + let client = setup_redis_client(Some(config.redis_url.clone())); + let team = insert_new_team_in_redis(client.clone()).await.unwrap(); + let token = team.api_token; + + let server = ServerHandle::for_config(config).await; + + let payload = json!({ + "token": token, + "distinct_id": distinct_id, + "groups": {"group1": "group1"} + }); + let res = server.send_invalid_header_for_flags_request(payload.to_string()).await; + assert_eq!(StatusCode::BAD_REQUEST, res.status()); + + // We don't want to deserialize the data into a flagResponse struct here, + // because we want to assert the shape of the raw json data. + let response_text = res.text().await?; + + assert_eq!(response_text, "failed to decode request: unsupported content type: xyz"); + + Ok(()) +} \ No newline at end of file