diff --git a/src/channel.rs b/src/channel.rs index 6124932..46c8c98 100644 --- a/src/channel.rs +++ b/src/channel.rs @@ -2,6 +2,7 @@ use serde::{Deserialize, Serialize}; use tonic_openssl_lnd::lnrpc::{self, channel_point}; +use crate::auth::AuthUser; use crate::{AppState, MAX_SEND_AMOUNT}; #[derive(Clone, Deserialize)] @@ -20,6 +21,7 @@ pub struct ChannelResponse { pub async fn open_channel( state: &AppState, x_forwarded_for: &str, + user: Option<&AuthUser>, payload: ChannelRequest, ) -> anyhow::Result { if payload.capacity > MAX_SEND_AMOUNT.try_into()? { @@ -90,7 +92,7 @@ pub async fn open_channel( state .payments - .add_payment(x_forwarded_for, None, None, payload.capacity as u64) + .add_payment(x_forwarded_for, None, user, payload.capacity as u64) .await; Ok(txid) diff --git a/src/lightning.rs b/src/lightning.rs index 2af7ce8..c6e2d37 100644 --- a/src/lightning.rs +++ b/src/lightning.rs @@ -11,6 +11,7 @@ use nostr::{EventBuilder, Filter, JsonUtil, Kind, Metadata, UncheckedUrl}; use std::str::FromStr; use tonic_openssl_lnd::lnrpc; +use crate::auth::AuthUser; use crate::nostr_dms::RELAYS; use crate::{AppState, MAX_SEND_AMOUNT}; @@ -27,6 +28,7 @@ pub struct LightningResponse { pub async fn pay_lightning( state: &AppState, x_forwarded_for: &str, + user: Option<&AuthUser>, bolt11: &str, ) -> anyhow::Result { let params = PaymentParams::from_str(bolt11).map_err(|_| anyhow::anyhow!("invalid bolt 11"))?; @@ -110,7 +112,7 @@ pub async fn pay_lightning( .add_payment( x_forwarded_for, None, - None, + user, invoice.amount_milli_satoshis().unwrap_or(0) / 1000, ) .await; diff --git a/src/main.rs b/src/main.rs index 36d171a..da01aa8 100644 --- a/src/main.rs +++ b/src/main.rs @@ -87,11 +87,17 @@ async fn main() -> anyhow::Result<()> { "/api/onchain", post(onchain_handler).route_layer(middleware::from_fn(auth_middleware)), ) - .route("/api/lightning", post(lightning_handler)) + .route( + "/api/lightning", + post(lightning_handler).route_layer(middleware::from_fn(auth_middleware)), + ) .route("/api/lnurlw", get(lnurlw_handler)) .route("/api/lnurlw/callback", get(lnurlw_callback_handler)) .route("/api/bolt11", post(bolt11_handler)) - .route("/api/channel", post(channel_handler)) + .route( + "/api/channel", + post(channel_handler).route_layer(middleware::from_fn(auth_middleware)), + ) .fallback(fallback) .layer(Extension(state.clone())) .layer( @@ -274,6 +280,7 @@ async fn onchain_handler( #[axum::debug_handler] async fn lightning_handler( Extension(state): Extension, + Extension(user): Extension, headers: HeaderMap, Json(payload): Json, ) -> Result, AppError> { @@ -283,11 +290,15 @@ async fn lightning_handler( .and_then(|x| HeaderValue::to_str(x).ok()) .unwrap_or("Unknown"); - if state.payments.get_total_payments(x_forwarded_for).await > MAX_SEND_AMOUNT * 10 { + if state + .payments + .verify_payments(x_forwarded_for, None, Some(&user)) + .await + { return Err(AppError::new("Too many payments")); } - let payment_hash = pay_lightning(&state, x_forwarded_for, &payload.bolt11).await?; + let payment_hash = pay_lightning(&state, x_forwarded_for, Some(&user), &payload.bolt11).await?; Ok(Json(LightningResponse { payment_hash })) } @@ -329,7 +340,7 @@ async fn lnurlw_callback_handler( return Err(Json(json!({"status": "ERROR", "reason": "Incorrect k1"}))); } - pay_lightning(&state, x_forwarded_for, &payload.pr) + pay_lightning(&state, x_forwarded_for, None, &payload.pr) .await .map_err(|e| Json(json!({"status": "ERROR", "reason": format!("{e}")})))?; Ok(Json(json!({"status": "OK"}))) @@ -351,6 +362,7 @@ async fn bolt11_handler( #[axum::debug_handler] async fn channel_handler( Extension(state): Extension, + Extension(user): Extension, headers: HeaderMap, Json(payload): Json, ) -> Result, AppError> { @@ -360,11 +372,15 @@ async fn channel_handler( .and_then(|x| HeaderValue::to_str(x).ok()) .unwrap_or("Unknown"); - if state.payments.get_total_payments(x_forwarded_for).await > MAX_SEND_AMOUNT * 10 { + if state + .payments + .verify_payments(x_forwarded_for, None, Some(&user)) + .await + { return Err(AppError::new("Too many payments")); } - let txid = open_channel(&state, x_forwarded_for, payload).await?; + let txid = open_channel(&state, x_forwarded_for, Some(&user), payload).await?; Ok(Json(ChannelResponse { txid })) }