diff --git a/backend/src/rate_limit.rs b/backend/src/rate_limit.rs index 57d98fb..23ebcc1 100644 --- a/backend/src/rate_limit.rs +++ b/backend/src/rate_limit.rs @@ -60,6 +60,18 @@ pub struct LoginRateLimitKey { ip: String, } +#[cfg(test)] +const REQUESTS_PER_SECOND: u32 = 1; + +#[cfg(test)] +const REQUESTS_BURST: u32 = 5; + +#[cfg(not(test))] +const REQUESTS_PER_SECOND: u32 = 5; + +#[cfg(not(test))] +const REQUESTS_BURST: u32 = 25; + // RateLimiter for the login route pub struct LoginRateLimiter { rate_limiter: RateLimiter, QuantaClock, governor::middleware::NoOpMiddleware>, @@ -68,7 +80,7 @@ pub struct LoginRateLimiter { impl LoginRateLimiter { pub fn new() -> Self { Self { - rate_limiter: RateLimiter::keyed(Quota::per_second(NonZeroU32::new(1).unwrap()).allow_burst(NonZeroU32::new(5).unwrap())), + rate_limiter: RateLimiter::keyed(Quota::per_second(NonZeroU32::new(REQUESTS_PER_SECOND).unwrap()).allow_burst(NonZeroU32::new(REQUESTS_BURST).unwrap())), } } diff --git a/backend/src/routes/auth/get_login_salt.rs b/backend/src/routes/auth/get_login_salt.rs index b89f236..a9421ec 100644 --- a/backend/src/routes/auth/get_login_salt.rs +++ b/backend/src/routes/auth/get_login_salt.rs @@ -1,6 +1,6 @@ use std::sync::Mutex; -use actix_web::{get, web, HttpResponse, Responder}; +use actix_web::{get, web, HttpRequest, HttpResponse, Responder}; use db_connector::models::users::User; use diesel::{prelude::*, result::Error::NotFound}; use lru::LruCache; @@ -8,9 +8,7 @@ use serde::Deserialize; use utoipa::IntoParams; use crate::{ - error::Error, - utils::{generate_random_bytes, get_connection, web_block_unpacked}, - AppState, + error::Error, rate_limit::LoginRateLimiter, utils::{generate_random_bytes, get_connection, web_block_unpacked}, AppState }; #[derive(Deserialize, IntoParams)] @@ -34,10 +32,14 @@ pub async fn get_login_salt( state: web::Data, query: web::Query, cache: web::Data>>>, + rate_limiter: web::Data, + req: HttpRequest, ) -> actix_web::Result { use db_connector::schema::users::dsl::*; let mail = query.email.to_lowercase(); + rate_limiter.check(mail.clone(), &req)?; + let mut conn = get_connection(&state)?; let salt: Vec = web_block_unpacked(move || { match users