Skip to content

Commit

Permalink
Add IP address based rate limiting
Browse files Browse the repository at this point in the history
  • Loading branch information
benthecarman committed Jun 12, 2024
1 parent 3d1718a commit 9ac993e
Show file tree
Hide file tree
Showing 5 changed files with 165 additions and 6 deletions.
11 changes: 10 additions & 1 deletion src/channel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@ pub struct ChannelResponse {
pub txid: String,
}

pub async fn open_channel(state: AppState, payload: ChannelRequest) -> anyhow::Result<String> {
pub async fn open_channel(
state: AppState,
x_forwarded_for: &str,
payload: ChannelRequest,
) -> anyhow::Result<String> {
if payload.capacity > MAX_SEND_AMOUNT.try_into().unwrap() {
anyhow::bail!("max capacity is 10,000,000");
}
Expand Down Expand Up @@ -84,5 +88,10 @@ pub async fn open_channel(state: AppState, payload: ChannelRequest) -> anyhow::R
None => anyhow::bail!("failed to open channel"),
};

state
.payments
.add_payment(x_forwarded_for, payload.capacity as u64)
.await;

Ok(txid)
}
14 changes: 13 additions & 1 deletion src/lightning.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,11 @@ pub struct LightningResponse {
pub payment_hash: String,
}

pub async fn pay_lightning(state: AppState, bolt11: &str) -> anyhow::Result<String> {
pub async fn pay_lightning(
state: AppState,
x_forwarded_for: &str,
bolt11: &str,
) -> anyhow::Result<String> {
let params = PaymentParams::from_str(bolt11).map_err(|_| anyhow::anyhow!("invalid bolt 11"))?;

let invoice = if let Some(invoice) = params.invoice() {
Expand Down Expand Up @@ -114,5 +118,13 @@ pub async fn pay_lightning(state: AppState, bolt11: &str) -> anyhow::Result<Stri
response.payment_preimage
};

state
.payments
.add_payment(
x_forwarded_for,
invoice.amount_milli_satoshis().unwrap_or(0),
)
.await;

Ok(hex::encode(payment_preimage))
}
63 changes: 59 additions & 4 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use axum::extract::Query;
use axum::headers::{HeaderMap, HeaderValue};
use axum::http::Uri;
use axum::{
http::StatusCode,
Expand All @@ -19,6 +20,7 @@ use tonic_openssl_lnd::LndLightningClient;
use tower_http::cors::{AllowHeaders, AllowMethods, Any, CorsLayer};

use crate::nostr_dms::listen_to_nostr_dms;
use crate::payments::PaymentsByIp;
use bolt11::{request_bolt11, Bolt11Request, Bolt11Response};
use channel::{open_channel, ChannelRequest, ChannelResponse};
use lightning::{pay_lightning, LightningRequest, LightningResponse};
Expand All @@ -30,6 +32,7 @@ mod channel;
mod lightning;
mod nostr_dms;
mod onchain;
mod payments;
mod setup;

#[derive(Clone)]
Expand All @@ -39,6 +42,7 @@ pub struct AppState {
network: bitcoin::Network,
lightning_client: LndLightningClient,
lnurl: AsyncClient,
payments: PaymentsByIp,
}

impl AppState {
Expand All @@ -55,6 +59,7 @@ impl AppState {
network,
lightning_client,
lnurl,
payments: PaymentsByIp::new(),
}
}
}
Expand Down Expand Up @@ -138,19 +143,41 @@ async fn main() -> anyhow::Result<()> {
#[axum::debug_handler]
async fn onchain_handler(
Extension(state): Extension<AppState>,
headers: HeaderMap,
Json(payload): Json<OnchainRequest>,
) -> Result<Json<OnchainResponse>, AppError> {
let res = pay_onchain(state, payload).await?;
// Extract the X-Forwarded-For header
let x_forwarded_for = headers
.get("x-forwarded-for")
.and_then(|x| HeaderValue::to_str(x).ok())
.unwrap_or("Unknown");

if state.payments.get_total_payments(x_forwarded_for).await > MAX_SEND_AMOUNT * 10 {
return Err(AppError::new("Too many payments"));
}

let res = pay_onchain(state, x_forwarded_for, payload).await?;

Ok(Json(res))
}

#[axum::debug_handler]
async fn lightning_handler(
Extension(state): Extension<AppState>,
headers: HeaderMap,
Json(payload): Json<LightningRequest>,
) -> Result<Json<LightningResponse>, AppError> {
let payment_hash = pay_lightning(state, &payload.bolt11).await?;
// Extract the X-Forwarded-For header
let x_forwarded_for = headers
.get("x-forwarded-for")
.and_then(|x| HeaderValue::to_str(x).ok())
.unwrap_or("Unknown");

if state.payments.get_total_payments(x_forwarded_for).await > MAX_SEND_AMOUNT * 10 {
return Err(AppError::new("Too many payments"));
}

let payment_hash = pay_lightning(state, x_forwarded_for, &payload.bolt11).await?;

Ok(Json(LightningResponse { payment_hash }))
}
Expand Down Expand Up @@ -178,10 +205,21 @@ pub struct LnurlWithdrawParams {
#[axum::debug_handler]
async fn lnurlw_callback_handler(
Extension(state): Extension<AppState>,
headers: HeaderMap,
Query(payload): Query<LnurlWithdrawParams>,
) -> Result<Json<Value>, Json<Value>> {
if payload.k1 == "k1" {
pay_lightning(state, &payload.pr)
// Extract the X-Forwarded-For header
let x_forwarded_for = headers
.get("x-forwarded-for")
.and_then(|x| HeaderValue::to_str(x).ok())
.unwrap_or("Unknown");

if state.payments.get_total_payments(x_forwarded_for).await > MAX_SEND_AMOUNT * 10 {
return Err(Json(json!({"status": "ERROR", "reason": "Incorrect k1"})));
}

pay_lightning(state, x_forwarded_for, &payload.pr)
.await
.map_err(|e| Json(json!({"status": "ERROR", "reason": format!("{e}")})))?;
Ok(Json(json!({"status": "OK"})))
Expand All @@ -203,16 +241,33 @@ async fn bolt11_handler(
#[axum::debug_handler]
async fn channel_handler(
Extension(state): Extension<AppState>,
headers: HeaderMap,
Json(payload): Json<ChannelRequest>,
) -> Result<Json<ChannelResponse>, AppError> {
let txid = open_channel(state, payload.clone()).await?;
// Extract the X-Forwarded-For header
let x_forwarded_for = headers
.get("x-forwarded-for")
.and_then(|x| HeaderValue::to_str(x).ok())
.unwrap_or("Unknown");

if state.payments.get_total_payments(x_forwarded_for).await > MAX_SEND_AMOUNT * 10 {
return Err(AppError::new("Too many payments"));
}

let txid = open_channel(state, x_forwarded_for, payload).await?;

Ok(Json(ChannelResponse { txid }))
}

// Make our own error that wraps `anyhow::Error`.
struct AppError(anyhow::Error);

impl AppError {
fn new(msg: &'static str) -> Self {
AppError(anyhow::anyhow!(msg))
}
}

// Tell axum how to convert `AppError` into a response.
impl IntoResponse for AppError {
fn into_response(self) -> Response {
Expand Down
6 changes: 6 additions & 0 deletions src/onchain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ pub struct OnchainResponse {

pub async fn pay_onchain(
state: AppState,
x_forwarded_for: &str,
payload: OnchainRequest,
) -> anyhow::Result<OnchainResponse> {
let res = {
Expand Down Expand Up @@ -61,6 +62,11 @@ pub async fn pay_onchain(
wallet_client.send_coins(req).await?.into_inner()
};

state
.payments
.add_payment(x_forwarded_for, amount.to_sat())
.await;

OnchainResponse {
txid: resp.txid,
address: address.to_string(),
Expand Down
77 changes: 77 additions & 0 deletions src/payments.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
use std::collections::{HashMap, VecDeque};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::Mutex;

const CACHE_DURATION: Duration = Duration::from_secs(86_400); // 1 day

struct Payment {
time: Instant,
amount: u64,
}

struct PaymentTracker {
payments: VecDeque<Payment>,
}

impl PaymentTracker {
pub fn new() -> Self {
PaymentTracker {
payments: VecDeque::new(),
}
}

pub fn add_payment(&mut self, amount: u64) {
let now = Instant::now();
let payment = Payment { time: now, amount };

self.payments.push_back(payment);
}

fn clean_old_payments(&mut self) {
let now = Instant::now();
while let Some(payment) = self.payments.front() {
if now.duration_since(payment.time) < CACHE_DURATION {
break;
}

self.payments.pop_front();
}
}

pub fn sum_payments(&mut self) -> u64 {
self.clean_old_payments();
self.payments.iter().map(|p| p.amount).sum()
}
}

#[derive(Clone)]
pub struct PaymentsByIp {
trackers: Arc<Mutex<HashMap<String, PaymentTracker>>>,
}

impl PaymentsByIp {
pub fn new() -> Self {
PaymentsByIp {
trackers: Arc::new(Mutex::new(HashMap::new())),
}
}

// Add a payment to the tracker for the given ip
pub async fn add_payment(&self, ip: &str, amount: u64) {
let mut trackers = self.trackers.lock().await;
let tracker = trackers
.entry(ip.to_string())
.or_insert_with(PaymentTracker::new);
tracker.add_payment(amount);
}

// Get the total amount of payments for the given ip
pub async fn get_total_payments(&self, ip: &str) -> u64 {
let mut trackers = self.trackers.lock().await;
match trackers.get_mut(ip) {
Some(tracker) => tracker.sum_payments(),
None => 0,
}
}
}

0 comments on commit 9ac993e

Please sign in to comment.