diff --git a/Cargo.lock b/Cargo.lock index 9d5823e..c284283 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -231,6 +231,7 @@ dependencies = [ "futures", "hex", "jwt-compact", + "lazy-regex", "log", "mockall", "mockito", @@ -1001,6 +1002,29 @@ dependencies = [ "cpufeatures", ] +[[package]] +name = "lazy-regex" +version = "3.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d12be4595afdf58bd19e4a9f4e24187da2a66700786ff660a418e9059937a4c" +dependencies = [ + "lazy-regex-proc_macros", + "once_cell", + "regex", +] + +[[package]] +name = "lazy-regex-proc_macros" +version = "3.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44bcd58e6c97a7fcbaffcdc95728b393b8d98933bfadad49ed4097845b57ef0b" +dependencies = [ + "proc-macro2", + "quote", + "regex", + "syn 2.0.50", +] + [[package]] name = "lazy_static" version = "1.4.0" diff --git a/Cargo.toml b/Cargo.toml index e2a45ad..3dd722e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,6 +29,7 @@ reqwest = { version = "0.11", features = ["json"] } tokio = { version = "1.12.0", features = ["full"] } tower-http = { version = "0.4.0", features = ["cors"] } async-trait = "0.1.59" +lazy-regex = "3.1.0" [dev-dependencies] mockall = "0.11.2" diff --git a/migrations/2024-02-20-210617_user_info/up.sql b/migrations/2024-02-20-210617_user_info/up.sql index d8615af..4ad8425 100644 --- a/migrations/2024-02-20-210617_user_info/up.sql +++ b/migrations/2024-02-20-210617_user_info/up.sql @@ -1,8 +1,7 @@ CREATE TABLE app_user ( id SERIAL PRIMARY KEY, pubkey VARCHAR(64) NOT NULL, - name VARCHAR(20) NOT NULL, - dm_type VARCHAR(5) NOT NULL, + name VARCHAR(255) NOT NULL UNIQUE, federation_id VARCHAR(64) NOT NULL, federation_invite_code VARCHAR(255) NOT NULL ); diff --git a/src/db.rs b/src/db.rs index fa26908..5fe8fae 100644 --- a/src/db.rs +++ b/src/db.rs @@ -1,12 +1,15 @@ -use diesel::{pg::PgConnection, r2d2::ConnectionManager, r2d2::Pool, Connection}; +use diesel::{pg::PgConnection, r2d2::ConnectionManager, r2d2::Pool}; use std::sync::Arc; #[cfg(test)] use mockall::{automock, predicate::*}; +use crate::models::app_user::{AppUser, NewAppUser}; + #[cfg_attr(test, automock)] pub(crate) trait DBConnection { - // fn get_services(&self) -> anyhow::Result>; + fn check_name_available(&self, name: String) -> anyhow::Result; + fn insert_new_user(&self, name: NewAppUser) -> anyhow::Result; } pub(crate) struct PostgresConnection { @@ -14,12 +17,15 @@ pub(crate) struct PostgresConnection { } impl DBConnection for PostgresConnection { - /* - fn get_services(&self) -> anyhow::Result> { + fn check_name_available(&self, name: String) -> anyhow::Result { + let conn = &mut self.db.get()?; + AppUser::check_available_name(conn, name) + } + + fn insert_new_user(&self, new_user: NewAppUser) -> anyhow::Result { let conn = &mut self.db.get()?; - Service::get_services(conn) + new_user.insert(conn) } - */ } pub(crate) fn setup_db(url: String) -> Arc { diff --git a/src/main.rs b/src/main.rs index 6fe0830..7cb2819 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,11 +1,10 @@ use axum::extract::DefaultBodyLimit; use axum::headers::Origin; use axum::http::{request::Parts, HeaderValue, Method, StatusCode, Uri}; -use axum::routing::{get, post}; +use axum::routing::get; use axum::{http, Extension, Router, TypedHeader}; use log::{error, info}; -use secp256k1::{All, PublicKey, Secp256k1}; -use std::collections::HashMap; +use secp256k1::{All, Secp256k1}; use std::sync::Arc; use tokio::signal::unix::{signal, SignalKind}; use tokio::sync::oneshot; @@ -13,10 +12,12 @@ use tower_http::cors::{AllowOrigin, CorsLayer}; use crate::{ db::{setup_db, DBConnection}, - routes::{health_check, valid_origin, validate_cors}, + routes::{check_username, health_check, valid_origin, validate_cors}, }; mod db; +mod models; +mod register; mod routes; const ALLOWED_ORIGINS: [&str; 6] = [ @@ -81,6 +82,7 @@ async fn main() -> anyhow::Result<()> { let server_router = Router::new() .route("/health-check", get(health_check)) + .route("/check-username/:username", get(check_username)) .fallback(fallback) .layer( CorsLayer::new() diff --git a/src/models/app_user.rs b/src/models/app_user.rs new file mode 100644 index 0000000..94fa449 --- /dev/null +++ b/src/models/app_user.rs @@ -0,0 +1,72 @@ +use crate::models::schema::app_user; +use diesel::prelude::*; +use serde::{Deserialize, Serialize}; + +#[derive( + QueryableByName, Queryable, AsChangeset, Serialize, Deserialize, Debug, Clone, PartialEq, +)] +#[diesel(check_for_backend(diesel::pg::Pg))] +#[diesel(table_name = app_user)] +pub struct AppUser { + pub id: i32, + pub pubkey: String, + pub name: String, + pub federation_id: String, + pub federation_invite_code: String, +} + +impl AppUser { + pub fn get_app_users(conn: &mut PgConnection) -> anyhow::Result> { + Ok(app_user::table.load::(conn)?) + } + + pub fn get_by_id(conn: &mut PgConnection, user_id: i32) -> anyhow::Result> { + Ok(app_user::table + .filter(app_user::id.eq(user_id)) + .first::(conn) + .optional()?) + } + + pub fn get_by_name(conn: &mut PgConnection, name: String) -> anyhow::Result> { + Ok(app_user::table + .filter(app_user::name.eq(name)) + .first::(conn) + .optional()?) + } + + pub fn check_available_name(conn: &mut PgConnection, name: String) -> anyhow::Result { + Ok(app_user::table + .filter(app_user::name.eq(name)) + .count() + .get_result::(conn)? + == 0) + } + + pub fn get_by_pubkey( + conn: &mut PgConnection, + pubkey: String, + ) -> anyhow::Result> { + Ok(app_user::table + .filter(app_user::pubkey.eq(pubkey)) + .first::(conn) + .optional()?) + } +} + +#[derive(Insertable)] +#[diesel(table_name = app_user)] +pub struct NewAppUser { + pub pubkey: String, + pub name: String, + pub federation_id: String, + pub federation_invite_code: String, +} + +impl NewAppUser { + pub fn insert(&self, conn: &mut PgConnection) -> anyhow::Result { + diesel::insert_into(app_user::table) + .values(self) + .get_result::(conn) + .map_err(|e| e.into()) + } +} diff --git a/src/models/invoice.rs b/src/models/invoice.rs new file mode 100644 index 0000000..410a347 --- /dev/null +++ b/src/models/invoice.rs @@ -0,0 +1,64 @@ +use crate::models::schema::invoice; +use diesel::prelude::*; +use serde::{Deserialize, Serialize}; + +#[derive( + QueryableByName, + Queryable, + Insertable, + AsChangeset, + Serialize, + Deserialize, + Debug, + Clone, + PartialEq, +)] +#[diesel(check_for_backend(diesel::pg::Pg))] +#[diesel(table_name = invoice)] +pub struct Invoice { + pub id: i32, + pub federation_id: String, + pub op_id: String, + pub app_user_id: i32, + pub bolt11: String, + pub amount: i64, + pub state: i32, +} + +impl Invoice { + pub fn insert(&self, conn: &mut PgConnection) -> anyhow::Result<()> { + diesel::insert_into(invoice::table) + .values(self) + .execute(conn)?; + + Ok(()) + } + + pub fn get_invoices(conn: &mut PgConnection) -> anyhow::Result> { + Ok(invoice::table.load::(conn)?) + } + + pub fn get_by_id(conn: &mut PgConnection, user_id: i32) -> anyhow::Result> { + Ok(invoice::table + .filter(invoice::id.eq(user_id)) + .first::(conn) + .optional()?) + } + + pub fn get_by_operation( + conn: &mut PgConnection, + op_id: String, + ) -> anyhow::Result> { + Ok(invoice::table + .filter(invoice::op_id.eq(op_id)) + .first::(conn) + .optional()?) + } + + pub fn get_by_state(conn: &mut PgConnection, state: i32) -> anyhow::Result> { + Ok(invoice::table + .filter(invoice::state.eq(state)) + .first::(conn) + .optional()?) + } +} diff --git a/src/models/mod.rs b/src/models/mod.rs new file mode 100644 index 0000000..02fc1ff --- /dev/null +++ b/src/models/mod.rs @@ -0,0 +1,4 @@ +pub mod app_user; +pub mod invoice; +mod schema; +pub mod zaps; diff --git a/src/models/schema.rs b/src/models/schema.rs index 539cb3a..4163efc 100644 --- a/src/models/schema.rs +++ b/src/models/schema.rs @@ -5,10 +5,8 @@ diesel::table! { id -> Int4, #[max_length = 64] pubkey -> Varchar, - #[max_length = 20] + #[max_length = 255] name -> Varchar, - #[max_length = 5] - dm_type -> Varchar, #[max_length = 64] federation_id -> Varchar, #[max_length = 255] diff --git a/src/models/zaps.rs b/src/models/zaps.rs new file mode 100644 index 0000000..c2e7d3a --- /dev/null +++ b/src/models/zaps.rs @@ -0,0 +1,52 @@ +use crate::models::schema::zaps; +use diesel::prelude::*; +use serde::{Deserialize, Serialize}; + +#[derive( + QueryableByName, + Queryable, + Insertable, + AsChangeset, + Serialize, + Deserialize, + Debug, + Clone, + PartialEq, +)] +#[diesel(check_for_backend(diesel::pg::Pg))] +#[diesel(table_name = zaps)] +pub struct Zap { + pub id: i32, + pub request: String, + pub event_id: Option, +} + +impl Zap { + pub fn insert(&self, conn: &mut PgConnection) -> anyhow::Result<()> { + diesel::insert_into(zaps::table) + .values(self) + .execute(conn)?; + + Ok(()) + } + + pub fn get_zaps(conn: &mut PgConnection) -> anyhow::Result> { + Ok(zaps::table.load::(conn)?) + } + + pub fn get_by_id(conn: &mut PgConnection, zap_id: i32) -> anyhow::Result> { + Ok(zaps::table + .filter(zaps::id.eq(zap_id)) + .first::(conn) + .optional()?) + } + + pub fn set_event_id(&self, conn: &mut PgConnection, event_id: String) -> anyhow::Result<()> { + diesel::update(zaps::table) + .filter(zaps::id.eq(self.id)) + .set(zaps::event_id.eq(event_id)) + .execute(conn)?; + + Ok(()) + } +} diff --git a/src/register.rs b/src/register.rs new file mode 100644 index 0000000..1c16f7f --- /dev/null +++ b/src/register.rs @@ -0,0 +1,79 @@ +use crate::State; +use lazy_regex::*; + +pub static ALPHANUMERIC_REGEX: Lazy = lazy_regex!("^[a-zA-Z0-9]+$"); + +pub fn is_valid_name(name: &str) -> bool { + if name.len() > 30 { + return false; + } + + ALPHANUMERIC_REGEX.is_match(name) +} + +pub async fn check_available(state: &State, name: String) -> anyhow::Result { + if !is_valid_name(&name) { + return Ok(false); + } + + state.db.check_name_available(name) +} + +#[cfg(all(test, not(feature = "integration-tests")))] +mod tests { + use crate::register::is_valid_name; + + #[tokio::test] + async fn check_name() { + // bad names + assert!(!is_valid_name("thisisoverthe30characternamelimit")); + assert!(!is_valid_name("thisisoverthe30characternamelimit")); + assert!(!is_valid_name("no!")); + assert!(!is_valid_name("bad&name")); + assert!(!is_valid_name("bad space name")); + assert!(!is_valid_name("bad_name")); + + // good + assert!(is_valid_name("goodname")); + assert!(is_valid_name("goodname1")); + assert!(is_valid_name("yesnameisverygoodandunderlimit")); + } +} + +#[cfg(all(test, feature = "integration-tests"))] +mod tests_integration { + use secp256k1::Secp256k1; + + use crate::{db::setup_db, models::app_user::NewAppUser, register::check_available, State}; + + #[tokio::test] + async fn test_username_checker() { + dotenv::dotenv().ok(); + let pg_url = std::env::var("DATABASE_URL").expect("DATABASE_URL must be set"); + let db = setup_db(pg_url); + let state = State { + db: db.clone(), + secp: Secp256k1::new(), + }; + + let name = "veryuniquename123".to_string(); + let available = check_available(&state, name).await.expect("should get"); + assert!(available); + + let commonname = "commonname".to_string(); + let common_app_user = NewAppUser { + pubkey: "".to_string(), + name: commonname.clone(), + federation_id: "".to_string(), + federation_invite_code: "".to_string(), + }; + + // don't care about error if already exists + let _ = state.db.insert_new_user(common_app_user); + + let available = check_available(&state, commonname) + .await + .expect("should get"); + assert!(!available); + } +} diff --git a/src/routes.rs b/src/routes.rs index 480fd15..8920f79 100644 --- a/src/routes.rs +++ b/src/routes.rs @@ -1,12 +1,28 @@ -use crate::{ALLOWED_LOCALHOST, ALLOWED_ORIGINS, ALLOWED_SUBDOMAIN, API_VERSION}; -use axum::headers::authorization::Bearer; -use axum::headers::{Authorization, Origin}; +use crate::{ + register::check_available, State, ALLOWED_LOCALHOST, ALLOWED_ORIGINS, ALLOWED_SUBDOMAIN, + API_VERSION, +}; +use axum::extract::Path; +use axum::headers::Origin; use axum::http::StatusCode; use axum::Extension; use axum::{Json, TypedHeader}; use log::{debug, error}; -use serde::{Deserialize, Serialize}; -use tbs::{BlindedMessage, BlindedSignature}; +use serde::Serialize; + +pub async fn check_username( + origin: Option>, + Extension(state): Extension, + Path(username): Path, +) -> Result, (StatusCode, String)> { + debug!("check_username: {}", username); + validate_cors(origin)?; + + match check_available(&state, username).await { + Ok(res) => Ok(Json(res)), + Err(e) => Err(handle_anyhow_error("check_username", e)), + } +} #[derive(Serialize)] pub struct HealthResponse {