diff --git a/src/channel.rs b/src/channel.rs index 992988c..1057f87 100644 --- a/src/channel.rs +++ b/src/channel.rs @@ -17,7 +17,11 @@ pub struct ChannelResponse { pub txid: String, } -pub async fn open_channel(state: AppState, payload: ChannelRequest) -> anyhow::Result { +pub async fn open_channel( + state: AppState, + x_forwarded_for: &str, + payload: ChannelRequest, +) -> anyhow::Result { if payload.capacity > MAX_SEND_AMOUNT.try_into().unwrap() { anyhow::bail!("max capacity is 10,000,000"); } @@ -84,5 +88,10 @@ pub async fn open_channel(state: AppState, payload: ChannelRequest) -> anyhow::R None => anyhow::bail!("failed to open channel"), }; + state + .payments + .add_payment(x_forwarded_for, payload.capacity as u64) + .await; + Ok(txid) } diff --git a/src/lightning.rs b/src/lightning.rs index c0d11b4..27a7f3a 100644 --- a/src/lightning.rs +++ b/src/lightning.rs @@ -23,7 +23,11 @@ pub struct LightningResponse { pub payment_hash: String, } -pub async fn pay_lightning(state: AppState, bolt11: &str) -> anyhow::Result { +pub async fn pay_lightning( + state: AppState, + x_forwarded_for: &str, + bolt11: &str, +) -> anyhow::Result { let params = PaymentParams::from_str(bolt11).map_err(|_| anyhow::anyhow!("invalid bolt 11"))?; let invoice = if let Some(invoice) = params.invoice() { @@ -114,5 +118,13 @@ pub async fn pay_lightning(state: AppState, bolt11: &str) -> anyhow::Result anyhow::Result<()> { #[axum::debug_handler] async fn onchain_handler( Extension(state): Extension, + headers: HeaderMap, Json(payload): Json, ) -> Result, AppError> { - let res = pay_onchain(state, payload).await?; + // Extract the X-Forwarded-For header + let x_forwarded_for = headers + .get("x-forwarded-for") + .and_then(|x| HeaderValue::to_str(x).ok()) + .unwrap_or("Unknown"); + + if state.payments.get_total_payments(x_forwarded_for).await > MAX_SEND_AMOUNT * 10 { + return Err(AppError::new("Too many payments")); + } + + let res = pay_onchain(state, x_forwarded_for, payload).await?; Ok(Json(res)) } @@ -148,9 +164,20 @@ async fn onchain_handler( #[axum::debug_handler] async fn lightning_handler( Extension(state): Extension, + headers: HeaderMap, Json(payload): Json, ) -> Result, AppError> { - let payment_hash = pay_lightning(state, &payload.bolt11).await?; + // Extract the X-Forwarded-For header + let x_forwarded_for = headers + .get("x-forwarded-for") + .and_then(|x| HeaderValue::to_str(x).ok()) + .unwrap_or("Unknown"); + + if state.payments.get_total_payments(x_forwarded_for).await > MAX_SEND_AMOUNT * 10 { + return Err(AppError::new("Too many payments")); + } + + let payment_hash = pay_lightning(state, x_forwarded_for, &payload.bolt11).await?; Ok(Json(LightningResponse { payment_hash })) } @@ -178,10 +205,21 @@ pub struct LnurlWithdrawParams { #[axum::debug_handler] async fn lnurlw_callback_handler( Extension(state): Extension, + headers: HeaderMap, Query(payload): Query, ) -> Result, Json> { if payload.k1 == "k1" { - pay_lightning(state, &payload.pr) + // Extract the X-Forwarded-For header + let x_forwarded_for = headers + .get("x-forwarded-for") + .and_then(|x| HeaderValue::to_str(x).ok()) + .unwrap_or("Unknown"); + + if state.payments.get_total_payments(x_forwarded_for).await > MAX_SEND_AMOUNT * 10 { + return Err(Json(json!({"status": "ERROR", "reason": "Incorrect k1"}))); + } + + pay_lightning(state, x_forwarded_for, &payload.pr) .await .map_err(|e| Json(json!({"status": "ERROR", "reason": format!("{e}")})))?; Ok(Json(json!({"status": "OK"}))) @@ -203,9 +241,20 @@ async fn bolt11_handler( #[axum::debug_handler] async fn channel_handler( Extension(state): Extension, + headers: HeaderMap, Json(payload): Json, ) -> Result, AppError> { - let txid = open_channel(state, payload.clone()).await?; + // Extract the X-Forwarded-For header + let x_forwarded_for = headers + .get("x-forwarded-for") + .and_then(|x| HeaderValue::to_str(x).ok()) + .unwrap_or("Unknown"); + + if state.payments.get_total_payments(x_forwarded_for).await > MAX_SEND_AMOUNT * 10 { + return Err(AppError::new("Too many payments")); + } + + let txid = open_channel(state, x_forwarded_for, payload).await?; Ok(Json(ChannelResponse { txid })) } @@ -213,6 +262,12 @@ async fn channel_handler( // Make our own error that wraps `anyhow::Error`. struct AppError(anyhow::Error); +impl AppError { + fn new(msg: &'static str) -> Self { + AppError(anyhow::anyhow!(msg)) + } +} + // Tell axum how to convert `AppError` into a response. impl IntoResponse for AppError { fn into_response(self) -> Response { diff --git a/src/onchain.rs b/src/onchain.rs index 14b419f..9296c0a 100644 --- a/src/onchain.rs +++ b/src/onchain.rs @@ -19,6 +19,7 @@ pub struct OnchainResponse { pub async fn pay_onchain( state: AppState, + x_forwarded_for: &str, payload: OnchainRequest, ) -> anyhow::Result { let res = { @@ -61,6 +62,11 @@ pub async fn pay_onchain( wallet_client.send_coins(req).await?.into_inner() }; + state + .payments + .add_payment(x_forwarded_for, amount.to_sat()) + .await; + OnchainResponse { txid: resp.txid, address: address.to_string(), diff --git a/src/payments.rs b/src/payments.rs new file mode 100644 index 0000000..b21ac86 --- /dev/null +++ b/src/payments.rs @@ -0,0 +1,77 @@ +use std::collections::{HashMap, VecDeque}; +use std::sync::Arc; +use std::time::{Duration, Instant}; +use tokio::sync::Mutex; + +const CACHE_DURATION: Duration = Duration::from_secs(86_400); // 1 day + +struct Payment { + time: Instant, + amount: u64, +} + +struct PaymentTracker { + payments: VecDeque, +} + +impl PaymentTracker { + pub fn new() -> Self { + PaymentTracker { + payments: VecDeque::new(), + } + } + + pub fn add_payment(&mut self, amount: u64) { + let now = Instant::now(); + let payment = Payment { time: now, amount }; + + self.payments.push_back(payment); + } + + fn clean_old_payments(&mut self) { + let now = Instant::now(); + while let Some(payment) = self.payments.front() { + if now.duration_since(payment.time) < CACHE_DURATION { + break; + } + + self.payments.pop_front(); + } + } + + pub fn sum_payments(&mut self) -> u64 { + self.clean_old_payments(); + self.payments.iter().map(|p| p.amount).sum() + } +} + +#[derive(Clone)] +pub struct PaymentsByIp { + trackers: Arc>>, +} + +impl PaymentsByIp { + pub fn new() -> Self { + PaymentsByIp { + trackers: Arc::new(Mutex::new(HashMap::new())), + } + } + + // Add a payment to the tracker for the given ip + pub async fn add_payment(&self, ip: &str, amount: u64) { + let mut trackers = self.trackers.lock().await; + let tracker = trackers + .entry(ip.to_string()) + .or_insert_with(PaymentTracker::new); + tracker.add_payment(amount); + } + + // Get the total amount of payments for the given ip + pub async fn get_total_payments(&self, ip: &str) -> u64 { + let mut trackers = self.trackers.lock().await; + match trackers.get_mut(ip) { + Some(tracker) => tracker.sum_payments(), + None => 0, + } + } +}