diff --git a/src/db.rs b/src/db.rs index 965117a..40457f1 100644 --- a/src/db.rs +++ b/src/db.rs @@ -754,7 +754,13 @@ impl Database { pub async fn read_txs( &self, relayer_id: &str, + tx_status_filter: Option>, ) -> eyre::Result> { + let (should_filter, status_filter) = match tx_status_filter { + Some(status) => (true, status), + None => (false, None), + }; + Ok(sqlx::query_as( r#" SELECT t.id as tx_id, t.tx_to as to, t.data, t.value, t.gas_limit, t.nonce, @@ -763,9 +769,12 @@ impl Database { LEFT JOIN sent_transactions s ON t.id = s.tx_id LEFT JOIN tx_hashes h ON s.valid_tx_hash = h.tx_hash WHERE t.relayer_id = $1 + AND ($2 = true AND s.status = $3) OR $2 = false "#, ) .bind(relayer_id) + .bind(should_filter) + .bind(status_filter) .fetch_all(&self.pool) .await?) } @@ -1310,6 +1319,9 @@ mod tests { assert_eq!(tx.nonce, 0); assert_eq!(tx.tx_hash, None); + let unsent_txs = db.read_txs(relayer_id, None).await?; + assert_eq!(unsent_txs.len(), 1, "1 unsent tx"); + let tx_hash_1 = H256::from_low_u64_be(1); let tx_hash_2 = H256::from_low_u64_be(2); let initial_max_fee_per_gas = U256::from(1); @@ -1328,6 +1340,18 @@ mod tests { assert_eq!(tx.tx_hash.unwrap().0, tx_hash_1); assert_eq!(tx.status, Some(TxStatus::Pending)); + let unsent_txs = db.read_txs(relayer_id, Some(None)).await?; + assert_eq!(unsent_txs.len(), 0, "0 unsent tx"); + + let pending_txs = db + .read_txs(relayer_id, Some(Some(TxStatus::Pending))) + .await?; + assert_eq!(pending_txs.len(), 1, "1 pending tx"); + + let all_txs = db.read_txs(relayer_id, None).await?; + + assert_eq!(all_txs, pending_txs); + db.escalate_tx( tx_id, tx_hash_2, diff --git a/src/db/data.rs b/src/db/data.rs index 7e8bd80..058feb0 100644 --- a/src/db/data.rs +++ b/src/db/data.rs @@ -43,7 +43,7 @@ pub struct TxForEscalation { pub escalation_count: usize, } -#[derive(Debug, Clone, FromRow)] +#[derive(Debug, Clone, FromRow, PartialEq, Eq)] pub struct ReadTxData { pub tx_id: String, pub to: AddressWrapper, diff --git a/src/server/routes/transaction.rs b/src/server/routes/transaction.rs index a1dbca7..33f7d9e 100644 --- a/src/server/routes/transaction.rs +++ b/src/server/routes/transaction.rs @@ -1,6 +1,6 @@ use std::sync::Arc; -use axum::extract::{Json, Path, State}; +use axum::extract::{Json, Path, Query, State}; use ethers::types::{Address, Bytes, H256, U256}; use eyre::Result; use serde::{Deserialize, Serialize}; @@ -33,6 +33,13 @@ pub struct SendTxResponse { pub tx_id: String, } +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct GetTxQuery { + #[serde(default)] + pub status: Option, +} + #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct GetTxResponse { @@ -104,12 +111,23 @@ pub async fn send_tx( pub async fn get_txs( State(app): State>, Path(api_token): Path, + Query(query): Query, ) -> Result>, ApiError> { if !app.is_authorized(&api_token).await? { return Err(ApiError::Unauthorized); } - let txs = app.db.read_txs(&api_token.relayer_id).await?; + let txs = match query.status { + Some(GetTxResponseStatus::TxStatus(status)) => { + app.db + .read_txs(&api_token.relayer_id, Some(Some(status))) + .await? + } + Some(GetTxResponseStatus::Unsent(_)) => { + app.db.read_txs(&api_token.relayer_id, Some(None)).await? + } + None => app.db.read_txs(&api_token.relayer_id, None).await?, + }; let txs = txs.into_iter()