From 5347bc3b6293ff4c8fbf325ae322b352aabc9f0c Mon Sep 17 00:00:00 2001 From: Fabricio Date: Mon, 30 Sep 2024 22:35:07 -0600 Subject: [PATCH 01/21] Modify DevApi trait with new endpoints --- crates/katana/rpc/rpc-api/src/dev.rs | 11 +++++++++++ crates/katana/rpc/rpc/src/dev.rs | 12 ++++++++++++ 2 files changed, 23 insertions(+) diff --git a/crates/katana/rpc/rpc-api/src/dev.rs b/crates/katana/rpc/rpc-api/src/dev.rs index bcc404ffa1..04c3600272 100644 --- a/crates/katana/rpc/rpc-api/src/dev.rs +++ b/crates/katana/rpc/rpc-api/src/dev.rs @@ -24,4 +24,15 @@ pub trait DevApi { #[method(name = "predeployedAccounts")] async fn predeployed_accounts(&self) -> RpcResult>; + + #[method(name = "accountBalance")] + async fn account_balance (&self, contract_address: Felt) -> RpcResult; + + #[method(name = "feeToken")] + async fn fee_token (&self) -> RpcResult; + + #[method(name = "mint")] + async fn mint (&self) -> RpcResult<()>; + + } diff --git a/crates/katana/rpc/rpc/src/dev.rs b/crates/katana/rpc/rpc/src/dev.rs index f2d039bcdf..2ad5204e9e 100644 --- a/crates/katana/rpc/rpc/src/dev.rs +++ b/crates/katana/rpc/rpc/src/dev.rs @@ -92,6 +92,18 @@ impl DevApiServer for DevApi { Ok(()) } + async fn account_balance(&self, _contract_address: Felt) -> Result { + Ok(1) + } + + async fn fee_token(&self,) -> Result { + Ok(1) + } + + async fn mint(&self) -> Result<(), Error> { + Ok(()) + } + #[allow(deprecated)] async fn predeployed_accounts(&self) -> Result, Error> { Ok(self.backend.config.genesis.accounts().map(|e| Account::new(*e.0, e.1)).collect()) From 75155c08436a403b0fa9981394a5877234bc9c57 Mon Sep 17 00:00:00 2001 From: Fabricio Robles Date: Thu, 3 Oct 2024 13:13:11 -0600 Subject: [PATCH 02/21] Implementation of account_balance method --- crates/katana/node/src/lib.rs | 1 + crates/katana/rpc/rpc-api/src/dev.rs | 10 ++++---- crates/katana/rpc/rpc/Cargo.toml | 1 + crates/katana/rpc/rpc/src/dev.rs | 35 ++++++++++++++++++++++++---- 4 files changed, 36 insertions(+), 11 deletions(-) diff --git a/crates/katana/node/src/lib.rs b/crates/katana/node/src/lib.rs index 55dab949f0..df50c11dea 100644 --- a/crates/katana/node/src/lib.rs +++ b/crates/katana/node/src/lib.rs @@ -299,6 +299,7 @@ pub async fn spawn( let middleware = tower::ServiceBuilder::new() .option_layer(cors) .layer(ProxyGetRequestLayer::new("/", "health")?) + .layer(ProxyGetRequestLayer::new("/account_balance", "dev_accountBalance")?) 
.timeout(Duration::from_secs(20)); let server = ServerBuilder::new() diff --git a/crates/katana/rpc/rpc-api/src/dev.rs b/crates/katana/rpc/rpc-api/src/dev.rs index 04c3600272..558e7579e0 100644 --- a/crates/katana/rpc/rpc-api/src/dev.rs +++ b/crates/katana/rpc/rpc-api/src/dev.rs @@ -20,19 +20,17 @@ pub trait DevApi { #[method(name = "setStorageAt")] async fn set_storage_at(&self, contract_address: Felt, key: Felt, value: Felt) - -> RpcResult<()>; + -> RpcResult<()>; #[method(name = "predeployedAccounts")] async fn predeployed_accounts(&self) -> RpcResult>; #[method(name = "accountBalance")] - async fn account_balance (&self, contract_address: Felt) -> RpcResult; + async fn account_balance(&self, account_address: &str) -> RpcResult; #[method(name = "feeToken")] - async fn fee_token (&self) -> RpcResult; + async fn fee_token(&self) -> RpcResult; #[method(name = "mint")] - async fn mint (&self) -> RpcResult<()>; - - + async fn mint(&self) -> RpcResult<()>; } diff --git a/crates/katana/rpc/rpc/Cargo.toml b/crates/katana/rpc/rpc/Cargo.toml index b836ebb139..b3c53d8fbc 100644 --- a/crates/katana/rpc/rpc/Cargo.toml +++ b/crates/katana/rpc/rpc/Cargo.toml @@ -24,6 +24,7 @@ metrics.workspace = true starknet.workspace = true thiserror.workspace = true tracing.workspace = true +url = "2.5.0" [dev-dependencies] alloy = { git = "https://github.com/alloy-rs/alloy", features = [ "contract", "network", "node-bindings", "provider-http", "providers", "signer-local" ] } diff --git a/crates/katana/rpc/rpc/src/dev.rs b/crates/katana/rpc/rpc/src/dev.rs index 2ad5204e9e..aba5a0bb65 100644 --- a/crates/katana/rpc/rpc/src/dev.rs +++ b/crates/katana/rpc/rpc/src/dev.rs @@ -4,10 +4,16 @@ use jsonrpsee::core::{async_trait, Error}; use katana_core::backend::Backend; use katana_core::service::block_producer::{BlockProducer, BlockProducerMode, PendingExecutor}; use katana_executor::ExecutorFactory; -use katana_primitives::Felt; +use katana_primitives::genesis::constant::DEFAULT_FEE_TOKEN_ADDRESS; +use katana_primitives::{address, ContractAddress, Felt}; use katana_rpc_api::dev::DevApiServer; use katana_rpc_types::account::Account; use katana_rpc_types::error::dev::DevApiError; +use starknet::core::types::{BlockId, BlockTag, FunctionCall}; +use starknet::macros::selector; +use starknet::providers::jsonrpc::HttpTransport; +use starknet::providers::{JsonRpcClient, Provider}; +use url::Url; #[allow(missing_debug_implementations)] pub struct DevApi { @@ -54,7 +60,6 @@ impl DevApi { let mut block_context_generator = self.backend.block_context_generator.write(); block_context_generator.block_timestamp_offset += offset as i64; - Ok(()) } } @@ -92,11 +97,31 @@ impl DevApiServer for DevApi { Ok(()) } - async fn account_balance(&self, _contract_address: Felt) -> Result { - Ok(1) + #[allow(deprecated)] + async fn account_balance(&self, account_address: &str) -> Result { + // let account_address = + // address!("0x6b86e40118f29ebe393a75469b4d926c7a44c2e2681b6d319520b7c1156d114"); + let account_address = Felt::from_dec_str(&account_address).unwrap(); + let account_address = ContractAddress::from(account_address); + let url = Url::parse("http://localhost:5050").unwrap(); + let provider = Arc::new(JsonRpcClient::new(HttpTransport::new(url))); + let res = provider + .call( + FunctionCall { + contract_address: DEFAULT_FEE_TOKEN_ADDRESS.into(), + entry_point_selector: selector!("balanceOf"), + calldata: vec![account_address.into()], + }, + BlockId::Tag(BlockTag::Latest), + ) + .await; + + let balance: u128 = 
res.unwrap()[0].to_string().parse().unwrap(); + + Ok(balance) } - async fn fee_token(&self,) -> Result { + async fn fee_token(&self) -> Result { Ok(1) } From e2612e30192e60031b3ebd5aab60095cfca4bb66 Mon Sep 17 00:00:00 2001 From: Fabricio Robles Date: Thu, 3 Oct 2024 13:34:15 -0600 Subject: [PATCH 03/21] clippy.sh and rust_fmt.sh --- crates/katana/rpc/rpc-api/src/dev.rs | 2 +- crates/katana/rpc/rpc/src/dev.rs | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/crates/katana/rpc/rpc-api/src/dev.rs b/crates/katana/rpc/rpc-api/src/dev.rs index 558e7579e0..e4fce9f112 100644 --- a/crates/katana/rpc/rpc-api/src/dev.rs +++ b/crates/katana/rpc/rpc-api/src/dev.rs @@ -20,7 +20,7 @@ pub trait DevApi { #[method(name = "setStorageAt")] async fn set_storage_at(&self, contract_address: Felt, key: Felt, value: Felt) - -> RpcResult<()>; + -> RpcResult<()>; #[method(name = "predeployedAccounts")] async fn predeployed_accounts(&self) -> RpcResult>; diff --git a/crates/katana/rpc/rpc/src/dev.rs b/crates/katana/rpc/rpc/src/dev.rs index aba5a0bb65..fe9ee24b4b 100644 --- a/crates/katana/rpc/rpc/src/dev.rs +++ b/crates/katana/rpc/rpc/src/dev.rs @@ -5,7 +5,7 @@ use katana_core::backend::Backend; use katana_core::service::block_producer::{BlockProducer, BlockProducerMode, PendingExecutor}; use katana_executor::ExecutorFactory; use katana_primitives::genesis::constant::DEFAULT_FEE_TOKEN_ADDRESS; -use katana_primitives::{address, ContractAddress, Felt}; +use katana_primitives::{ContractAddress, Felt}; use katana_rpc_api::dev::DevApiServer; use katana_rpc_types::account::Account; use katana_rpc_types::error::dev::DevApiError; @@ -101,7 +101,7 @@ impl DevApiServer for DevApi { async fn account_balance(&self, account_address: &str) -> Result { // let account_address = // address!("0x6b86e40118f29ebe393a75469b4d926c7a44c2e2681b6d319520b7c1156d114"); - let account_address = Felt::from_dec_str(&account_address).unwrap(); + let account_address = Felt::from_dec_str(account_address).unwrap(); let account_address = ContractAddress::from(account_address); let url = Url::parse("http://localhost:5050").unwrap(); let provider = Arc::new(JsonRpcClient::new(HttpTransport::new(url))); From 9902cfe782b9f1d8e1bc4e0ec4ce3c4a08e4f777 Mon Sep 17 00:00:00 2001 From: Fabricio Robles Date: Fri, 11 Oct 2024 09:51:36 +0200 Subject: [PATCH 04/21] Replace RPC call with querying the storage directly --- crates/katana/node/src/lib.rs | 1 + crates/katana/rpc/rpc-api/src/dev.rs | 4 +-- crates/katana/rpc/rpc/Cargo.toml | 1 - crates/katana/rpc/rpc/src/dev.rs | 54 +++++++++++++--------------- 4 files changed, 28 insertions(+), 32 deletions(-) diff --git a/crates/katana/node/src/lib.rs b/crates/katana/node/src/lib.rs index df50c11dea..23f8ca5885 100644 --- a/crates/katana/node/src/lib.rs +++ b/crates/katana/node/src/lib.rs @@ -300,6 +300,7 @@ pub async fn spawn( .option_layer(cors) .layer(ProxyGetRequestLayer::new("/", "health")?) .layer(ProxyGetRequestLayer::new("/account_balance", "dev_accountBalance")?) + .layer(ProxyGetRequestLayer::new("/fee_token", "dev_feeToken")?) 
.timeout(Duration::from_secs(20)); let server = ServerBuilder::new() diff --git a/crates/katana/rpc/rpc-api/src/dev.rs b/crates/katana/rpc/rpc-api/src/dev.rs index e4fce9f112..d833b9b491 100644 --- a/crates/katana/rpc/rpc-api/src/dev.rs +++ b/crates/katana/rpc/rpc-api/src/dev.rs @@ -26,10 +26,10 @@ pub trait DevApi { async fn predeployed_accounts(&self) -> RpcResult>; #[method(name = "accountBalance")] - async fn account_balance(&self, account_address: &str) -> RpcResult; + async fn account_balance(&self) -> RpcResult; #[method(name = "feeToken")] - async fn fee_token(&self) -> RpcResult; + async fn fee_token(&self) -> RpcResult; #[method(name = "mint")] async fn mint(&self) -> RpcResult<()>; diff --git a/crates/katana/rpc/rpc/Cargo.toml b/crates/katana/rpc/rpc/Cargo.toml index b3c53d8fbc..b836ebb139 100644 --- a/crates/katana/rpc/rpc/Cargo.toml +++ b/crates/katana/rpc/rpc/Cargo.toml @@ -24,7 +24,6 @@ metrics.workspace = true starknet.workspace = true thiserror.workspace = true tracing.workspace = true -url = "2.5.0" [dev-dependencies] alloy = { git = "https://github.com/alloy-rs/alloy", features = [ "contract", "network", "node-bindings", "provider-http", "providers", "signer-local" ] } diff --git a/crates/katana/rpc/rpc/src/dev.rs b/crates/katana/rpc/rpc/src/dev.rs index fe9ee24b4b..3b0ec11ed0 100644 --- a/crates/katana/rpc/rpc/src/dev.rs +++ b/crates/katana/rpc/rpc/src/dev.rs @@ -4,16 +4,13 @@ use jsonrpsee::core::{async_trait, Error}; use katana_core::backend::Backend; use katana_core::service::block_producer::{BlockProducer, BlockProducerMode, PendingExecutor}; use katana_executor::ExecutorFactory; -use katana_primitives::genesis::constant::DEFAULT_FEE_TOKEN_ADDRESS; -use katana_primitives::{ContractAddress, Felt}; +use katana_primitives::genesis::constant::ERC20_NAME_STORAGE_SLOT; +use katana_primitives::{address, ContractAddress, Felt}; +use katana_provider::traits::state::StateFactoryProvider; use katana_rpc_api::dev::DevApiServer; use katana_rpc_types::account::Account; use katana_rpc_types::error::dev::DevApiError; -use starknet::core::types::{BlockId, BlockTag, FunctionCall}; -use starknet::macros::selector; -use starknet::providers::jsonrpc::HttpTransport; -use starknet::providers::{JsonRpcClient, Provider}; -use url::Url; +use starknet::core::utils::get_storage_var_address; #[allow(missing_debug_implementations)] pub struct DevApi { @@ -98,31 +95,30 @@ impl DevApiServer for DevApi { } #[allow(deprecated)] - async fn account_balance(&self, account_address: &str) -> Result { - // let account_address = - // address!("0x6b86e40118f29ebe393a75469b4d926c7a44c2e2681b6d319520b7c1156d114"); - let account_address = Felt::from_dec_str(account_address).unwrap(); - let account_address = ContractAddress::from(account_address); - let url = Url::parse("http://localhost:5050").unwrap(); - let provider = Arc::new(JsonRpcClient::new(HttpTransport::new(url))); - let res = provider - .call( - FunctionCall { - contract_address: DEFAULT_FEE_TOKEN_ADDRESS.into(), - entry_point_selector: selector!("balanceOf"), - calldata: vec![account_address.into()], - }, - BlockId::Tag(BlockTag::Latest), - ) - .await; - - let balance: u128 = res.unwrap()[0].to_string().parse().unwrap(); - + async fn account_balance(&self) -> Result { + let account_address = + address!("0x6b86e40118f29ebe393a75469b4d926c7a44c2e2681b6d319520b7c1156d114"); //This is temp + let provider = self.backend.blockchain.provider(); + let state = provider.latest().unwrap(); + let storage_slot = + get_storage_var_address("ERC20_balances", 
&[account_address.into()]).unwrap(); + let balance_felt = state + .storage(self.backend.config.genesis.fee_token.address, storage_slot) + .unwrap() + .unwrap(); + let balance: u128 = balance_felt.to_string().parse().unwrap(); Ok(balance) } - async fn fee_token(&self) -> Result { - Ok(1) + #[allow(deprecated)] + async fn fee_token(&self) -> Result { + let provider = self.backend.blockchain.provider(); + let state = provider.latest().unwrap(); + let fee_token = state + .storage(self.backend.config.genesis.fee_token.address, ERC20_NAME_STORAGE_SLOT) + .unwrap() + .unwrap(); + Ok(fee_token.to_string()) } async fn mint(&self) -> Result<(), Error> { From b2e3ad47a1b63b47997a08047a6f68f62c49b0b2 Mon Sep 17 00:00:00 2001 From: Fabricio Date: Tue, 22 Oct 2024 17:46:03 -0600 Subject: [PATCH 05/21] -Re implement proxy_get_request.rs -Send params to RPC method --- crates/katana/node/src/lib.rs | 9 +- crates/katana/rpc/rpc-api/src/dev.rs | 4 +- crates/katana/rpc/rpc/Cargo.toml | 24 +- crates/katana/rpc/rpc/src/dev.rs | 9 +- crates/katana/rpc/rpc/src/future.rs | 220 +++++ crates/katana/rpc/rpc/src/lib.rs | 5 + crates/katana/rpc/rpc/src/logger.rs | 191 ++++ .../katana/rpc/rpc/src/proxy_get_request.rs | 163 ++++ crates/katana/rpc/rpc/src/server.rs | 886 ++++++++++++++++++ crates/katana/rpc/rpc/src/transport/http.rs | 502 ++++++++++ crates/katana/rpc/rpc/src/transport/mod.rs | 2 + crates/katana/rpc/rpc/src/transport/ws.rs | 613 ++++++++++++ 12 files changed, 2616 insertions(+), 12 deletions(-) create mode 100644 crates/katana/rpc/rpc/src/future.rs create mode 100644 crates/katana/rpc/rpc/src/logger.rs create mode 100644 crates/katana/rpc/rpc/src/proxy_get_request.rs create mode 100644 crates/katana/rpc/rpc/src/server.rs create mode 100644 crates/katana/rpc/rpc/src/transport/http.rs create mode 100644 crates/katana/rpc/rpc/src/transport/mod.rs create mode 100644 crates/katana/rpc/rpc/src/transport/ws.rs diff --git a/crates/katana/node/src/lib.rs b/crates/katana/node/src/lib.rs index 23f8ca5885..62024e475a 100644 --- a/crates/katana/node/src/lib.rs +++ b/crates/katana/node/src/lib.rs @@ -7,7 +7,8 @@ use std::time::Duration; use anyhow::Result; use dojo_metrics::{metrics_process, prometheus_exporter, Report}; use hyper::{Method, Uri}; -use jsonrpsee::server::middleware::proxy_get_request::ProxyGetRequestLayer; +// use jsonrpsee::server::middleware::proxy_get_request::ProxyGetRequestLayer; +use katana_rpc::proxy_get_request::DevnetProxyLayer; use jsonrpsee::server::{AllowHosts, ServerBuilder, ServerHandle}; use jsonrpsee::RpcModule; use katana_core::backend::config::StarknetConfig; @@ -298,9 +299,9 @@ pub async fn spawn( let middleware = tower::ServiceBuilder::new() .option_layer(cors) - .layer(ProxyGetRequestLayer::new("/", "health")?) - .layer(ProxyGetRequestLayer::new("/account_balance", "dev_accountBalance")?) - .layer(ProxyGetRequestLayer::new("/fee_token", "dev_feeToken")?) + .layer(DevnetProxyLayer::new("/", "health")?) + .layer(DevnetProxyLayer::new("/account_balance", "dev_accountBalance")?) + .layer(DevnetProxyLayer::new("/fee_token", "dev_feeToken")?) 
.timeout(Duration::from_secs(20)); let server = ServerBuilder::new() diff --git a/crates/katana/rpc/rpc-api/src/dev.rs b/crates/katana/rpc/rpc-api/src/dev.rs index d833b9b491..61889c1007 100644 --- a/crates/katana/rpc/rpc-api/src/dev.rs +++ b/crates/katana/rpc/rpc-api/src/dev.rs @@ -20,13 +20,13 @@ pub trait DevApi { #[method(name = "setStorageAt")] async fn set_storage_at(&self, contract_address: Felt, key: Felt, value: Felt) - -> RpcResult<()>; + -> RpcResult<()>; #[method(name = "predeployedAccounts")] async fn predeployed_accounts(&self) -> RpcResult>; #[method(name = "accountBalance")] - async fn account_balance(&self) -> RpcResult; + async fn account_balance(&self, address: String) -> RpcResult; #[method(name = "feeToken")] async fn fee_token(&self) -> RpcResult; diff --git a/crates/katana/rpc/rpc/Cargo.toml b/crates/katana/rpc/rpc/Cargo.toml index b836ebb139..42f904fd2f 100644 --- a/crates/katana/rpc/rpc/Cargo.toml +++ b/crates/katana/rpc/rpc/Cargo.toml @@ -11,6 +11,11 @@ anyhow.workspace = true dojo-metrics.workspace = true futures.workspace = true jsonrpsee = { workspace = true, features = [ "server" ] } +jsonrpsee-core = { version = "0.16.3", features = [ "server", "soketto", "http-helpers" ] } +jsonrpsee-types = { version = "0.16.3"} +hyper.workspace = true +tower = { workspace = true, features = [ "full" ] } +http = { version = "0.2.7" } katana-core.workspace = true katana-executor.workspace = true katana-pool.workspace = true @@ -24,6 +29,23 @@ metrics.workspace = true starknet.workspace = true thiserror.workspace = true tracing.workspace = true +serde.workspace = true +serde_json.workspace = true +soketto = { version = "0.7.1", features = ["http"] } +tokio = { version = "1.16", features = [ + "net", + "rt-multi-thread", + "macros", + "time", +]} +futures-channel = { version = "0.3.14"} +futures-util = { version = "0.3.14", features = [ + "io", + "async-await-macro", +]} +tokio-stream = { version = "0.1.7" } +tokio-util = { version = "0.7", features = ["compat"]} +starknet-crypto.workspace = true [dev-dependencies] alloy = { git = "https://github.com/alloy-rs/alloy", features = [ "contract", "network", "node-bindings", "provider-http", "providers", "signer-local" ] } @@ -42,8 +64,6 @@ katana-runner.workspace = true rstest.workspace = true num-traits.workspace = true rand.workspace = true -serde.workspace = true -serde_json.workspace = true tempfile.workspace = true tokio.workspace = true url.workspace = true diff --git a/crates/katana/rpc/rpc/src/dev.rs b/crates/katana/rpc/rpc/src/dev.rs index 3b0ec11ed0..82e31736ad 100644 --- a/crates/katana/rpc/rpc/src/dev.rs +++ b/crates/katana/rpc/rpc/src/dev.rs @@ -1,3 +1,4 @@ +use std::str::FromStr; use std::sync::Arc; use jsonrpsee::core::{async_trait, Error}; @@ -5,7 +6,8 @@ use katana_core::backend::Backend; use katana_core::service::block_producer::{BlockProducer, BlockProducerMode, PendingExecutor}; use katana_executor::ExecutorFactory; use katana_primitives::genesis::constant::ERC20_NAME_STORAGE_SLOT; -use katana_primitives::{address, ContractAddress, Felt}; +use katana_primitives::ContractAddress; +use starknet_crypto::Felt; use katana_provider::traits::state::StateFactoryProvider; use katana_rpc_api::dev::DevApiServer; use katana_rpc_types::account::Account; @@ -95,9 +97,8 @@ impl DevApiServer for DevApi { } #[allow(deprecated)] - async fn account_balance(&self) -> Result { - let account_address = - address!("0x6b86e40118f29ebe393a75469b4d926c7a44c2e2681b6d319520b7c1156d114"); //This is temp + async fn account_balance(&self, 
address: String) -> Result { + let account_address: ContractAddress = Felt::from_str(&address).unwrap().into(); let provider = self.backend.blockchain.provider(); let state = provider.latest().unwrap(); let storage_slot = diff --git a/crates/katana/rpc/rpc/src/future.rs b/crates/katana/rpc/rpc/src/future.rs new file mode 100644 index 0000000000..15ec7e29d0 --- /dev/null +++ b/crates/katana/rpc/rpc/src/future.rs @@ -0,0 +1,220 @@ +// Copyright 2019-2021 Parity Technologies (UK) Ltd. +// +// Permission is hereby granted, free of charge, to any +// person obtaining a copy of this software and associated +// documentation files (the "Software"), to deal in the +// Software without restriction, including without +// limitation the rights to use, copy, modify, merge, +// publish, distribute, sublicense, and/or sell copies of +// the Software, and to permit persons to whom the Software +// is furnished to do so, subject to the following +// conditions: +// +// The above copyright notice and this permission notice +// shall be included in all copies or substantial portions +// of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF +// ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED +// TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A +// PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT +// SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR +// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. + +//! Utilities for handling async code. + +use std::future::Future; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + +use futures_util::future::FutureExt; +use jsonrpsee_core::Error; +use tokio::sync::{watch, OwnedSemaphorePermit, Semaphore, TryAcquireError}; +use tokio::time::{self, Duration, Interval}; + +/// Polling for server stop monitor interval in milliseconds. +const STOP_MONITOR_POLLING_INTERVAL: Duration = Duration::from_millis(1000); + +/// This is a flexible collection of futures that need to be driven to completion +/// alongside some other future, such as connection handlers that need to be +/// handled along with a listener for new connections. +/// +/// In order to `.await` on these futures and drive them to completion, call +/// `select_with` providing some other future, the result of which you need. +pub(crate) struct FutureDriver { + futures: Vec, + stop_monitor_heartbeat: Interval, +} + +impl Default for FutureDriver { + fn default() -> Self { + let mut heartbeat = time::interval(STOP_MONITOR_POLLING_INTERVAL); + + heartbeat.set_missed_tick_behavior(time::MissedTickBehavior::Skip); + + FutureDriver { futures: Vec::new(), stop_monitor_heartbeat: heartbeat } + } +} + +impl FutureDriver { + /// Add a new future to this driver + pub(crate) fn add(&mut self, future: F) { + self.futures.push(future); + } +} + +impl FutureDriver +where + F: Future + Unpin, +{ + pub(crate) async fn select_with(&mut self, selector: S) -> S::Output { + tokio::pin!(selector); + + DriverSelect { selector, driver: self }.await + } + + fn drive(&mut self, cx: &mut Context) { + let mut i = 0; + + while i < self.futures.len() { + if self.futures[i].poll_unpin(cx).is_ready() { + // Using `swap_remove` since we don't care about ordering + // but we do care about removing being `O(1)`. 
+ // + // We don't increment `i` in this branch, since we now + // have a shorter length, and potentially a new value at + // current index + self.futures.swap_remove(i); + } else { + i += 1; + } + } + } + + fn poll_stop_monitor_heartbeat(&mut self, cx: &mut Context) { + // We don't care about the ticks of the heartbeat, it's here only + // to periodically wake the `Waker` on `cx`. + let _ = self.stop_monitor_heartbeat.poll_tick(cx); + } +} + +impl Future for FutureDriver +where + F: Future + Unpin, +{ + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { + let this = Pin::into_inner(self); + + this.drive(cx); + + if this.futures.is_empty() { + Poll::Ready(()) + } else { + Poll::Pending + } + } +} + +/// This is a glorified select `Future` that will attempt to drive all +/// connection futures `F` to completion on each `poll`, while also +/// handling incoming connections. +struct DriverSelect<'a, S, F> { + selector: S, + driver: &'a mut FutureDriver, +} + +impl<'a, R, F> Future for DriverSelect<'a, R, F> +where + R: Future + Unpin, + F: Future + Unpin, +{ + type Output = R::Output; + + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { + let this = Pin::into_inner(self); + + this.driver.drive(cx); + this.driver.poll_stop_monitor_heartbeat(cx); + + this.selector.poll_unpin(cx) + } +} + +#[derive(Debug, Clone)] +pub(crate) struct StopHandle(watch::Receiver<()>); + +impl StopHandle { + pub(crate) fn new(rx: watch::Receiver<()>) -> Self { + Self(rx) + } + + pub(crate) fn shutdown_requested(&self) -> bool { + // if a message has been seen, it means that `stop` has been called. + self.0.has_changed().unwrap_or(true) + } + + pub(crate) async fn shutdown(&mut self) { + // Err(_) implies that the `sender` has been dropped. + // Ok(_) implies that `stop` has been called. + let _ = self.0.changed().await; + } +} + +/// Server handle. +/// +/// When all [`StopHandle`]'s have been `dropped` or `stop` has been called +/// the server will be stopped. +#[derive(Debug, Clone)] +pub struct ServerHandle(Arc>); + +impl ServerHandle { + /// Create a new server handle. + pub fn new(tx: watch::Sender<()>) -> Self { + Self(Arc::new(tx)) + } + + /// Tell the server to stop without waiting for the server to stop. + pub fn stop(&self) -> Result<(), Error> { + self.0.send(()).map_err(|_| Error::AlreadyStopped) + } + + /// Wait for the server to stop. + pub async fn stopped(self) { + self.0.closed().await + } + + /// Check if the server has been stopped. + pub fn is_stopped(&self) -> bool { + self.0.is_closed() + } +} + +/// Limits the number of connections. 
+#[derive(Debug)] +pub(crate) struct ConnectionGuard(Arc); + +impl ConnectionGuard { + pub(crate) fn new(limit: usize) -> Self { + Self(Arc::new(Semaphore::new(limit))) + } + + pub(crate) fn try_acquire(&self) -> Option { + match self.0.clone().try_acquire_owned() { + Ok(guard) => Some(guard), + Err(TryAcquireError::Closed) => { + unreachable!("Semaphore::Close is never called and can't be closed; qed") + } + Err(TryAcquireError::NoPermits) => None, + } + } + + pub(crate) fn available_connections(&self) -> usize { + self.0.available_permits() + } +} diff --git a/crates/katana/rpc/rpc/src/lib.rs b/crates/katana/rpc/rpc/src/lib.rs index 6abe5d449e..5c56b2d0a6 100644 --- a/crates/katana/rpc/rpc/src/lib.rs +++ b/crates/katana/rpc/rpc/src/lib.rs @@ -9,5 +9,10 @@ pub mod metrics; pub mod saya; pub mod starknet; pub mod torii; +pub mod proxy_get_request; mod utils; +mod transport; +mod future; +mod logger; +mod server; diff --git a/crates/katana/rpc/rpc/src/logger.rs b/crates/katana/rpc/rpc/src/logger.rs new file mode 100644 index 0000000000..53e5eb4477 --- /dev/null +++ b/crates/katana/rpc/rpc/src/logger.rs @@ -0,0 +1,191 @@ +// Copyright 2019-2021 Parity Technologies (UK) Ltd. +// +// Permission is hereby granted, free of charge, to any +// person obtaining a copy of this software and associated +// documentation files (the "Software"), to deal in the +// Software without restriction, including without +// limitation the rights to use, copy, modify, merge, +// publish, distribute, sublicense, and/or sell copies of +// the Software, and to permit persons to whom the Software +// is furnished to do so, subject to the following +// conditions: +// +// The above copyright notice and this permission notice +// shall be included in all copies or substantial portions +// of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF +// ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED +// TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A +// PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT +// SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR +// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. + +//! Logger for `jsonrpsee` servers. + +use std::net::SocketAddr; + +/// HTTP request. +pub type HttpRequest = http::Request; +pub use hyper::Body; +pub use jsonrpsee_types::Params; + +/// The type JSON-RPC v2 call, it can be a subscription, method call or unknown. +#[derive(Debug, Copy, Clone)] +pub enum MethodKind { + /// Subscription Call. + Subscription, + /// Unsubscription Call. + Unsubscription, + /// Method call. + MethodCall, + /// Unknown method. + Unknown, +} + +impl std::fmt::Display for MethodKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let s = match self { + Self::Subscription => "subscription", + Self::MethodCall => "method call", + Self::Unknown => "unknown", + Self::Unsubscription => "unsubscription", + }; + + write!(f, "{}", s) + } +} + +/// The transport protocol used to send or receive a call or request. +#[derive(Debug, Copy, Clone)] +pub enum TransportProtocol { + /// HTTP transport. + Http, + /// WebSocket transport. 
+ WebSocket, +} + +impl std::fmt::Display for TransportProtocol { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let s = match self { + Self::Http => "http", + Self::WebSocket => "websocket", + }; + + write!(f, "{}", s) + } +} + +/// Defines a logger specifically for WebSocket connections with callbacks during the RPC request life-cycle. +/// The primary use case for this is to collect timings for a larger metrics collection solution. +/// +/// See the [`ServerBuilder::set_logger`](../../jsonrpsee_server/struct.ServerBuilder.html#method.set_logger) +/// for examples. +pub trait Logger: Send + Sync + Clone + 'static { + /// Intended to carry timestamp of a request, for example `std::time::Instant`. How the trait + /// measures time, if at all, is entirely up to the implementation. + type Instant: std::fmt::Debug + Send + Sync + Copy; + + /// Called when a new client connects + fn on_connect(&self, _remote_addr: SocketAddr, _request: &HttpRequest, _t: TransportProtocol); + + /// Called when a new JSON-RPC request comes to the server. + fn on_request(&self, transport: TransportProtocol) -> Self::Instant; + + /// Called on each JSON-RPC method call, batch requests will trigger `on_call` multiple times. + fn on_call( + &self, + method_name: &str, + params: Params, + kind: MethodKind, + transport: TransportProtocol, + ); + + /// Called on each JSON-RPC method completion, batch requests will trigger `on_result` multiple times. + fn on_result( + &self, + method_name: &str, + success: bool, + started_at: Self::Instant, + transport: TransportProtocol, + ); + + /// Called once the JSON-RPC request is finished and response is sent to the output buffer. + fn on_response(&self, result: &str, started_at: Self::Instant, transport: TransportProtocol); + + /// Called when a client disconnects + fn on_disconnect(&self, _remote_addr: SocketAddr, transport: TransportProtocol); +} + +impl Logger for () { + type Instant = (); + + fn on_connect(&self, _: SocketAddr, _: &HttpRequest, _p: TransportProtocol) -> Self::Instant {} + + fn on_request(&self, _p: TransportProtocol) -> Self::Instant {} + + fn on_call(&self, _: &str, _: Params, _: MethodKind, _p: TransportProtocol) {} + + fn on_result(&self, _: &str, _: bool, _: Self::Instant, _p: TransportProtocol) {} + + fn on_response(&self, _: &str, _: Self::Instant, _p: TransportProtocol) {} + + fn on_disconnect(&self, _: SocketAddr, _p: TransportProtocol) {} +} + +impl Logger for (A, B) +where + A: Logger, + B: Logger, +{ + type Instant = (A::Instant, B::Instant); + + fn on_connect( + &self, + remote_addr: std::net::SocketAddr, + request: &HttpRequest, + transport: TransportProtocol, + ) { + self.0.on_connect(remote_addr, request, transport); + self.1.on_connect(remote_addr, request, transport); + } + + fn on_request(&self, transport: TransportProtocol) -> Self::Instant { + (self.0.on_request(transport), self.1.on_request(transport)) + } + + fn on_call( + &self, + method_name: &str, + params: Params, + kind: MethodKind, + transport: TransportProtocol, + ) { + self.0.on_call(method_name, params.clone(), kind, transport); + self.1.on_call(method_name, params, kind, transport); + } + + fn on_result( + &self, + method_name: &str, + success: bool, + started_at: Self::Instant, + transport: TransportProtocol, + ) { + self.0.on_result(method_name, success, started_at.0, transport); + self.1.on_result(method_name, success, started_at.1, transport); + } + + fn on_response(&self, result: &str, started_at: Self::Instant, transport: TransportProtocol) { + 
self.0.on_response(result, started_at.0, transport); + self.1.on_response(result, started_at.1, transport); + } + + fn on_disconnect(&self, remote_addr: SocketAddr, transport: TransportProtocol) { + self.0.on_disconnect(remote_addr, transport); + self.1.on_disconnect(remote_addr, transport); + } +} diff --git a/crates/katana/rpc/rpc/src/proxy_get_request.rs b/crates/katana/rpc/rpc/src/proxy_get_request.rs new file mode 100644 index 0000000000..9b55ed9e84 --- /dev/null +++ b/crates/katana/rpc/rpc/src/proxy_get_request.rs @@ -0,0 +1,163 @@ +//! Middleware that proxies requests at a specified URI to internal +//! RPC method calls. + +use crate::transport::http; +use hyper::body; +use hyper::header::{ACCEPT, CONTENT_TYPE}; +use hyper::http::HeaderValue; +use hyper::{Body, Method, Request, Response, Uri}; +use jsonrpsee_core::error::Error as RpcError; +use jsonrpsee_core::JsonRawValue; +use jsonrpsee_types::{Id, RequestSer}; +use std::error::Error; +use std::future::Future; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; +use tower::{Layer, Service}; + +/// Layer that applies [`DevnetProxy`] which proxies the `GET /path` requests to +/// specific RPC method calls and that strips the response. +/// +/// See [`DevnetProxy`] for more details. +#[derive(Debug, Clone)] +pub struct DevnetProxyLayer { + path: String, + method: String, +} + +impl DevnetProxyLayer { + /// Creates a new [`DevnetProxyLayer`]. + /// + /// See [`DevnetProxy`] for more details. + pub fn new(path: impl Into, method: impl Into) -> Result { + let path = path.into(); + if !path.starts_with('/') { + return Err(RpcError::Custom( + "DevnetProxyLayer path must start with `/`".to_string(), + )); + } + + Ok(Self { path, method: method.into() }) + } +} +impl Layer for DevnetProxyLayer { + type Service = DevnetProxy; + + fn layer(&self, inner: S) -> Self::Service { + DevnetProxy::new(inner, &self.path, &self.method) + .expect("Path already validated in DevnetProxyLayer; qed") + } +} + +/// Proxy `GET /path` requests to the specified RPC method calls. +/// +/// # Request +/// +/// The `GET /path` requests are modified into valid `POST` requests for +/// calling the RPC method. This middleware adds appropriate headers to the +/// request, and completely modifies the request `BODY`. +/// +/// # Response +/// +/// The response of the RPC method is stripped down to contain only the method's +/// response, removing any RPC 2.0 spec logic regarding the response' body. +#[derive(Debug, Clone)] +pub struct DevnetProxy { + inner: S, + path: Arc, + method: Arc, +} + +impl DevnetProxy { + /// Creates a new [`DevnetProxy`]. + /// + /// The request `GET /path` is redirected to the provided method. + /// Fails if the path does not start with `/`. 
+ pub fn new(inner: S, path: &str, method: &str) -> Result { + if !path.starts_with('/') { + return Err(RpcError::Custom(format!( + "DevnetProxy path must start with `/`, got: {}", + path + ))); + } + + Ok(Self { inner, path: Arc::from(path), method: Arc::from(method) }) + } +} + +impl Service> for DevnetProxy +where + S: Service, Response = Response>, + S::Response: 'static, + S::Error: Into> + 'static, + S::Future: Send + 'static, +{ + type Response = S::Response; + type Error = Box; + type Future = + Pin> + Send + 'static>>; + + #[inline] + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx).map_err(Into::into) + } + + fn call(&mut self, mut req: Request) -> Self::Future { + let modify = self.path.as_ref() == req.uri() && req.method() == Method::GET; + + // Proxy the request to the appropriate method call. + if modify { + // RPC methods are accessed with `POST`. + *req.method_mut() = Method::POST; + // Precautionary remove the URI. + *req.uri_mut() = Uri::from_static("/"); + + // Requests must have the following headers: + req.headers_mut().insert(CONTENT_TYPE, HeaderValue::from_static("application/json")); + req.headers_mut().insert(ACCEPT, HeaderValue::from_static("application/json")); + + // Adjust the body to reflect the method call. + let raw_value = JsonRawValue::from_string("{\"address\":\"0x6b86e40118f29ebe393a75469b4d926c7a44c2e2681b6d319520b7c1156d114\", \"age\":5, \"name\":\"somename\"}".to_string()).unwrap(); + let param = Some(raw_value.as_ref()); + + let body = Body::from( + serde_json::to_string(&RequestSer::borrowed(&Id::Number(0), &self.method, param)) + .expect("Valid request; qed"), + ); + req = req.map(|_| body); + } + + // Call the inner service and get a future that resolves to the response. + let fut = self.inner.call(req); + + // Adjust the response if needed. + let res_fut = async move { + let res = fut.await.map_err(|err| err.into())?; + + // Nothing to modify: return the response as is. + if !modify { + return Ok(res); + } + + let body = res.into_body(); + let bytes = hyper::body::to_bytes(body).await?; + + #[derive(serde::Deserialize, Debug)] + struct RpcPayload<'a> { + #[serde(borrow)] + result: &'a serde_json::value::RawValue, + } + + let response = if let Ok(payload) = serde_json::from_slice::(&bytes) { + http::response::ok_response(payload.result.to_string()) + } else { + http::response::internal_error() + }; + + Ok(response) + }; + + Box::pin(res_fut) + } +} diff --git a/crates/katana/rpc/rpc/src/server.rs b/crates/katana/rpc/rpc/src/server.rs new file mode 100644 index 0000000000..c17aefe350 --- /dev/null +++ b/crates/katana/rpc/rpc/src/server.rs @@ -0,0 +1,886 @@ +// Copyright 2019-2021 Parity Technologies (UK) Ltd. +// +// Permission is hereby granted, free of charge, to any +// person obtaining a copy of this software and associated +// documentation files (the "Software"), to deal in the +// Software without restriction, including without +// limitation the rights to use, copy, modify, merge, +// publish, distribute, sublicense, and/or sell copies of +// the Software, and to permit persons to whom the Software +// is furnished to do so, subject to the following +// conditions: +// +// The above copyright notice and this permission notice +// shall be included in all copies or substantial portions +// of the Software. 
+// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF +// ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED +// TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A +// PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT +// SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR +// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. + +use std::error::Error as StdError; +use std::future::Future; +use std::net::{SocketAddr, TcpListener as StdTcpListener}; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; +use std::time::Duration; + +use crate::future::{ConnectionGuard, FutureDriver, ServerHandle, StopHandle}; +use crate::logger::{Logger, TransportProtocol}; +use crate::transport::{http, ws}; + +use futures_util::future::{BoxFuture, FutureExt}; +use futures_util::io::{BufReader, BufWriter}; + +use hyper::body::HttpBody; +use jsonrpsee_core::id_providers::RandomIntegerIdProvider; + +use jsonrpsee_core::server::helpers::MethodResponse; +use jsonrpsee_core::server::host_filtering::AllowHosts; +use jsonrpsee_core::server::resource_limiting::Resources; +use jsonrpsee_core::server::rpc_module::Methods; +use jsonrpsee_core::traits::IdProvider; +use jsonrpsee_core::{http_helpers, Error, TEN_MB_SIZE_BYTES}; + +use soketto::handshake::http::is_upgrade_request; +use tokio::net::{TcpListener, TcpStream, ToSocketAddrs}; +use tokio::sync::{watch, OwnedSemaphorePermit}; +use tokio_util::compat::TokioAsyncReadCompatExt; +use tower::layer::util::Identity; +use tower::{Layer, Service}; +use tracing::{instrument, Instrument}; + +/// Default maximum connections allowed. +const MAX_CONNECTIONS: u32 = 100; + +/// JSON RPC server. +pub struct Server { + listener: TcpListener, + cfg: Settings, + resources: Resources, + logger: L, + id_provider: Arc, + service_builder: tower::ServiceBuilder, +} + +impl std::fmt::Debug for Server { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Server") + .field("listener", &self.listener) + .field("cfg", &self.cfg) + .field("id_provider", &self.id_provider) + .field("resources", &self.resources) + .finish() + } +} + +impl Server { + /// Returns socket address to which the server is bound. + pub fn local_addr(&self) -> Result { + self.listener.local_addr().map_err(Into::into) + } +} + +impl Server +where + L: Logger, + B: Layer> + Send + 'static, + >>::Service: Send + + Service< + hyper::Request, + Response = hyper::Response, + Error = Box<(dyn StdError + Send + Sync + 'static)>, + >, + <>>::Service as Service>>::Future: Send, + U: HttpBody + Send + 'static, + ::Error: Send + Sync + StdError, + ::Data: Send, +{ + /// Start responding to connections requests. + /// + /// This will run on the tokio runtime until the server is stopped or the `ServerHandle` is dropped. 
+ pub fn start(mut self, methods: impl Into) -> Result { + let methods = methods.into().initialize_resources(&self.resources)?; + let (stop_tx, stop_rx) = watch::channel(()); + + let stop_handle = StopHandle::new(stop_rx); + + match self.cfg.tokio_runtime.take() { + Some(rt) => rt.spawn(self.start_inner(methods, stop_handle)), + None => tokio::spawn(self.start_inner(methods, stop_handle)), + }; + + Ok(ServerHandle::new(stop_tx)) + } + + async fn start_inner(self, methods: Methods, stop_handle: StopHandle) { + let max_request_body_size = self.cfg.max_request_body_size; + let max_response_body_size = self.cfg.max_response_body_size; + let max_log_length = self.cfg.max_log_length; + let allow_hosts = self.cfg.allow_hosts; + let resources = self.resources; + let logger = self.logger; + let batch_requests_supported = self.cfg.batch_requests_supported; + let id_provider = self.id_provider; + let max_subscriptions_per_connection = self.cfg.max_subscriptions_per_connection; + + let mut id: u32 = 0; + let connection_guard = ConnectionGuard::new(self.cfg.max_connections as usize); + let mut connections = FutureDriver::default(); + let mut incoming = Monitored::new(Incoming(self.listener), &stop_handle); + + loop { + match connections.select_with(&mut incoming).await { + Ok((socket, remote_addr)) => { + let data = ProcessConnection { + remote_addr, + methods: methods.clone(), + allow_hosts: allow_hosts.clone(), + resources: resources.clone(), + max_request_body_size, + max_response_body_size, + max_log_length, + batch_requests_supported, + id_provider: id_provider.clone(), + ping_interval: self.cfg.ping_interval, + stop_handle: stop_handle.clone(), + max_subscriptions_per_connection, + conn_id: id, + logger: logger.clone(), + max_connections: self.cfg.max_connections, + enable_http: self.cfg.enable_http, + enable_ws: self.cfg.enable_ws, + }; + process_connection( + &self.service_builder, + &connection_guard, + data, + socket, + &mut connections, + ); + id = id.wrapping_add(1); + } + Err(MonitoredError::Selector(err)) => { + tracing::error!("Error while awaiting a new connection: {:?}", err); + } + Err(MonitoredError::Shutdown) => break, + } + } + + connections.await; + } +} + +/// JSON-RPC Websocket server settings. +#[derive(Debug, Clone)] +struct Settings { + /// Maximum size in bytes of a request. + max_request_body_size: u32, + /// Maximum size in bytes of a response. + max_response_body_size: u32, + /// Maximum number of incoming connections allowed. + max_connections: u32, + /// Maximum number of subscriptions per connection. + max_subscriptions_per_connection: u32, + /// Max length for logging for requests and responses + /// + /// Logs bigger than this limit will be truncated. + max_log_length: u32, + /// Host filtering. + allow_hosts: AllowHosts, + /// Whether batch requests are supported by this server or not. + batch_requests_supported: bool, + /// Custom tokio runtime to run the server on. + tokio_runtime: Option, + /// The interval at which `Ping` frames are submitted. + ping_interval: Duration, + /// Enable HTTP. + enable_http: bool, + /// Enable WS. 
+ enable_ws: bool, +} + +impl Default for Settings { + fn default() -> Self { + Self { + max_request_body_size: TEN_MB_SIZE_BYTES, + max_response_body_size: TEN_MB_SIZE_BYTES, + max_log_length: 4096, + max_subscriptions_per_connection: 1024, + max_connections: MAX_CONNECTIONS, + batch_requests_supported: true, + allow_hosts: AllowHosts::Any, + tokio_runtime: None, + ping_interval: Duration::from_secs(60), + enable_http: true, + enable_ws: true, + } + } +} + +/// Builder to configure and create a JSON-RPC server +#[derive(Debug)] +pub struct Builder { + settings: Settings, + resources: Resources, + logger: L, + id_provider: Arc, + service_builder: tower::ServiceBuilder, +} + +impl Default for Builder { + fn default() -> Self { + Builder { + settings: Settings::default(), + resources: Resources::default(), + logger: (), + id_provider: Arc::new(RandomIntegerIdProvider), + service_builder: tower::ServiceBuilder::new(), + } + } +} + +impl Builder { + /// Create a default server builder. + pub fn new() -> Self { + Self::default() + } +} + +impl Builder { + /// Set the maximum size of a request body in bytes. Default is 10 MiB. + pub fn max_request_body_size(mut self, size: u32) -> Self { + self.settings.max_request_body_size = size; + self + } + + /// Set the maximum size of a response body in bytes. Default is 10 MiB. + pub fn max_response_body_size(mut self, size: u32) -> Self { + self.settings.max_response_body_size = size; + self + } + + /// Set the maximum number of connections allowed. Default is 100. + pub fn max_connections(mut self, max: u32) -> Self { + self.settings.max_connections = max; + self + } + + /// Enables or disables support of [batch requests](https://www.jsonrpc.org/specification#batch). + /// By default, support is enabled. + pub fn batch_requests_supported(mut self, supported: bool) -> Self { + self.settings.batch_requests_supported = supported; + self + } + + /// Set the maximum number of connections allowed. Default is 1024. + pub fn max_subscriptions_per_connection(mut self, max: u32) -> Self { + self.settings.max_subscriptions_per_connection = max; + self + } + + /// Register a new resource kind. Errors if `label` is already registered, or if the number of + /// registered resources on this server instance would exceed 8. + /// + /// See the module documentation for [`resurce_limiting`](../jsonrpsee_utils/server/resource_limiting/index.html#resource-limiting) + /// for details. + pub fn register_resource( + mut self, + label: &'static str, + capacity: u16, + default: u16, + ) -> Result { + self.resources.register(label, capacity, default)?; + Ok(self) + } + + /// Add a logger to the builder [`Logger`](../jsonrpsee_core/logger/trait.Logger.html). 
+ /// + /// ``` + /// use std::{time::Instant, net::SocketAddr}; + /// + /// use jsonrpsee_server::logger::{Logger, HttpRequest, MethodKind, Params, TransportProtocol}; + /// use jsonrpsee_server::ServerBuilder; + /// + /// #[derive(Clone)] + /// struct MyLogger; + /// + /// impl Logger for MyLogger { + /// type Instant = Instant; + /// + /// fn on_connect(&self, remote_addr: SocketAddr, request: &HttpRequest, transport: TransportProtocol) { + /// println!("[MyLogger::on_call] remote_addr: {:?}, headers: {:?}, transport: {}", remote_addr, request, transport); + /// } + /// + /// fn on_request(&self, transport: TransportProtocol) -> Self::Instant { + /// Instant::now() + /// } + /// + /// fn on_call(&self, method_name: &str, params: Params, kind: MethodKind, transport: TransportProtocol) { + /// println!("[MyLogger::on_call] method: '{}' params: {:?}, kind: {:?}, transport: {}", method_name, params, kind, transport); + /// } + /// + /// fn on_result(&self, method_name: &str, success: bool, started_at: Self::Instant, transport: TransportProtocol) { + /// println!("[MyLogger::on_result] '{}', worked? {}, time elapsed {:?}, transport: {}", method_name, success, started_at.elapsed(), transport); + /// } + /// + /// fn on_response(&self, result: &str, started_at: Self::Instant, transport: TransportProtocol) { + /// println!("[MyLogger::on_response] result: {}, time elapsed {:?}, transport: {}", result, started_at.elapsed(), transport); + /// } + /// + /// fn on_disconnect(&self, remote_addr: SocketAddr, transport: TransportProtocol) { + /// println!("[MyLogger::on_disconnect] remote_addr: {:?}, transport: {}", remote_addr, transport); + /// } + /// } + /// + /// let builder = ServerBuilder::new().set_logger(MyLogger); + /// ``` + pub fn set_logger(self, logger: T) -> Builder { + Builder { + settings: self.settings, + resources: self.resources, + logger, + id_provider: self.id_provider, + service_builder: self.service_builder, + } + } + + /// Configure a custom [`tokio::runtime::Handle`] to run the server on. + /// + /// Default: [`tokio::spawn`] + pub fn custom_tokio_runtime(mut self, rt: tokio::runtime::Handle) -> Self { + self.settings.tokio_runtime = Some(rt); + self + } + + /// Configure the interval at which pings are submitted. + /// + /// This option is used to keep the connection alive, and is just submitting `Ping` frames, + /// without making any assumptions about when a `Pong` frame should be received. + /// + /// Default: 60 seconds. + /// + /// # Examples + /// + /// ```rust + /// use std::time::Duration; + /// use jsonrpsee_server::ServerBuilder; + /// + /// // Set the ping interval to 10 seconds. + /// let builder = ServerBuilder::default().ping_interval(Duration::from_secs(10)); + /// ``` + pub fn ping_interval(mut self, interval: Duration) -> Self { + self.settings.ping_interval = interval; + self + } + + /// Configure custom `subscription ID` provider for the server to use + /// to when getting new subscription calls. + /// + /// You may choose static dispatch or dynamic dispatch because + /// `IdProvider` is implemented for `Box`. + /// + /// Default: [`RandomIntegerIdProvider`]. 
+ /// + /// # Examples + /// + /// ```rust + /// use jsonrpsee_server::{ServerBuilder, RandomStringIdProvider, IdProvider}; + /// + /// // static dispatch + /// let builder1 = ServerBuilder::default().set_id_provider(RandomStringIdProvider::new(16)); + /// + /// // or dynamic dispatch + /// let builder2 = ServerBuilder::default().set_id_provider(Box::new(RandomStringIdProvider::new(16))); + /// ``` + /// + pub fn set_id_provider(mut self, id_provider: I) -> Self { + self.id_provider = Arc::new(id_provider); + self + } + + /// Sets host filtering. + pub fn set_host_filtering(mut self, allow: AllowHosts) -> Self { + self.settings.allow_hosts = allow; + self + } + + /// Configure a custom [`tower::ServiceBuilder`] middleware for composing layers to be applied to the RPC service. + /// + /// Default: No tower layers are applied to the RPC service. + /// + /// # Examples + /// + /// ```rust + /// + /// use std::time::Duration; + /// use std::net::SocketAddr; + /// + /// #[tokio::main] + /// async fn main() { + /// let builder = tower::ServiceBuilder::new().timeout(Duration::from_secs(2)); + /// + /// let server = jsonrpsee_server::ServerBuilder::new() + /// .set_middleware(builder) + /// .build("127.0.0.1:0".parse::().unwrap()) + /// .await + /// .unwrap(); + /// } + /// ``` + pub fn set_middleware(self, service_builder: tower::ServiceBuilder) -> Builder { + Builder { + settings: self.settings, + resources: self.resources, + logger: self.logger, + id_provider: self.id_provider, + service_builder, + } + } + + /// Configure the server to only serve JSON-RPC HTTP requests. + /// + /// Default: both http and ws are enabled. + pub fn http_only(mut self) -> Self { + self.settings.enable_http = true; + self.settings.enable_ws = false; + self + } + + /// Configure the server to only serve JSON-RPC WebSocket requests. + /// + /// That implies that server just denies HTTP requests which isn't a WebSocket upgrade request + /// + /// Default: both http and ws are enabled. + pub fn ws_only(mut self) -> Self { + self.settings.enable_http = false; + self.settings.enable_ws = true; + self + } + + /// Finalize the configuration of the server. Consumes the [`Builder`]. + /// + /// ```rust + /// #[tokio::main] + /// async fn main() { + /// let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap(); + /// let occupied_addr = listener.local_addr().unwrap(); + /// let addrs: &[std::net::SocketAddr] = &[ + /// occupied_addr, + /// "127.0.0.1:0".parse().unwrap(), + /// ]; + /// assert!(jsonrpsee_server::ServerBuilder::default().build(occupied_addr).await.is_err()); + /// assert!(jsonrpsee_server::ServerBuilder::default().build(addrs).await.is_ok()); + /// } + /// ``` + /// + pub async fn build(self, addrs: impl ToSocketAddrs) -> Result, Error> { + let listener = TcpListener::bind(addrs).await?; + + Ok(Server { + listener, + cfg: self.settings, + resources: self.resources, + logger: self.logger, + id_provider: self.id_provider, + service_builder: self.service_builder, + }) + } + + /// Finalizes the configuration of the server with customized TCP settings on the socket. 
+ /// + /// + /// ```rust + /// use jsonrpsee_server::ServerBuilder; + /// use socket2::{Domain, Socket, Type}; + /// use std::time::Duration; + /// + /// #[tokio::main] + /// async fn main() { + /// let addr = "127.0.0.1:0".parse().unwrap(); + /// let domain = Domain::for_address(addr); + /// let socket = Socket::new(domain, Type::STREAM, None).unwrap(); + /// socket.set_nonblocking(true).unwrap(); + /// + /// let address = addr.into(); + /// socket.bind(&address).unwrap(); + /// + /// socket.listen(4096).unwrap(); + /// + /// let server = ServerBuilder::new().build_from_tcp(socket).unwrap(); + /// } + /// ``` + pub fn build_from_tcp( + self, + listener: impl Into, + ) -> Result, Error> { + let listener = TcpListener::from_std(listener.into())?; + + Ok(Server { + listener, + cfg: self.settings, + resources: self.resources, + logger: self.logger, + id_provider: self.id_provider, + service_builder: self.service_builder, + }) + } +} + +pub(crate) enum MethodResult { + JustLogger(MethodResponse), + SendAndLogger(MethodResponse), +} + +impl MethodResult { + pub(crate) fn as_inner(&self) -> &MethodResponse { + match &self { + Self::JustLogger(r) => r, + Self::SendAndLogger(r) => r, + } + } + + pub(crate) fn into_inner(self) -> MethodResponse { + match self { + Self::JustLogger(r) => r, + Self::SendAndLogger(r) => r, + } + } +} + +/// Data required by the server to handle requests. +#[derive(Debug, Clone)] +pub(crate) struct ServiceData { + /// Remote server address. + pub(crate) remote_addr: SocketAddr, + /// Registered server methods. + pub(crate) methods: Methods, + /// Access control. + pub(crate) allow_hosts: AllowHosts, + /// Tracker for currently used resources on the server. + pub(crate) resources: Resources, + /// Max request body size. + pub(crate) max_request_body_size: u32, + /// Max response body size. + pub(crate) max_response_body_size: u32, + /// Max length for logging for request and response + /// + /// Logs bigger than this limit will be truncated. + pub(crate) max_log_length: u32, + /// Whether batch requests are supported by this server or not. + pub(crate) batch_requests_supported: bool, + /// Subscription ID provider. + pub(crate) id_provider: Arc, + /// Ping interval + pub(crate) ping_interval: Duration, + /// Stop handle. + pub(crate) stop_handle: StopHandle, + /// Max subscriptions per connection. + pub(crate) max_subscriptions_per_connection: u32, + /// Connection ID + pub(crate) conn_id: u32, + /// Logger. + pub(crate) logger: L, + /// Handle to hold a `connection permit`. + pub(crate) conn: Arc, + /// Enable HTTP. + pub(crate) enable_http: bool, + /// Enable WS. + pub(crate) enable_ws: bool, +} + +/// JsonRPSee service compatible with `tower`. +/// +/// # Note +/// This is similar to [`hyper::service::service_fn`]. +#[derive(Debug, Clone)] +pub struct TowerService { + inner: ServiceData, +} + +impl hyper::service::Service> for TowerService { + type Response = hyper::Response; + + // The following associated type is required by the `impl Server` bounds. + // It satisfies the server's bounds when the `tower::ServiceBuilder` is not set (ie `B: Identity`). + type Error = Box; + + type Future = Pin> + Send>>; + + /// Opens door for back pressure implementation. 
+ fn poll_ready(&mut self, _: &mut Context) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, request: hyper::Request) -> Self::Future { + tracing::trace!("{:?}", request); + + let host = match http_helpers::read_header_value(request.headers(), hyper::header::HOST) { + Some(host) => host, + None if request.version() == hyper::Version::HTTP_2 => match request.uri().host() { + Some(host) => host, + None => return async move { Ok(http::response::malformed()) }.boxed(), + }, + None => return async move { Ok(http::response::malformed()) }.boxed(), + }; + + if let Err(e) = self.inner.allow_hosts.verify(host) { + tracing::warn!("Denied request: {}", e); + return async { Ok(http::response::host_not_allowed()) }.boxed(); + } + + let is_upgrade_request = is_upgrade_request(&request); + + if self.inner.enable_ws && is_upgrade_request { + let mut server = soketto::handshake::http::Server::new(); + + let response = match server.receive_request(&request) { + Ok(response) => { + self.inner.logger.on_connect( + self.inner.remote_addr, + &request, + TransportProtocol::WebSocket, + ); + let data = self.inner.clone(); + + tokio::spawn( + async move { + let upgraded = match hyper::upgrade::on(request).await { + Ok(u) => u, + Err(e) => { + tracing::warn!("Could not upgrade connection: {}", e); + return; + } + }; + + let stream = BufReader::new(BufWriter::new(upgraded.compat())); + let mut ws_builder = server.into_builder(stream); + ws_builder.set_max_message_size(data.max_request_body_size as usize); + let (sender, receiver) = ws_builder.finish(); + + let _ = ws::background_task::(sender, receiver, data).await; + } + .in_current_span(), + ); + + response.map(|()| hyper::Body::empty()) + } + Err(e) => { + tracing::error!("Could not upgrade connection: {}", e); + hyper::Response::new(hyper::Body::from(format!( + "Could not upgrade connection: {}", + e + ))) + } + }; + + async { Ok(response) }.boxed() + } else if self.inner.enable_http && !is_upgrade_request { + // The request wasn't an upgrade request; let's treat it as a standard HTTP request: + let data = http::HandleRequest { + methods: self.inner.methods.clone(), + resources: self.inner.resources.clone(), + max_request_body_size: self.inner.max_request_body_size, + max_response_body_size: self.inner.max_response_body_size, + max_log_length: self.inner.max_log_length, + batch_requests_supported: self.inner.batch_requests_supported, + logger: self.inner.logger.clone(), + conn: self.inner.conn.clone(), + remote_addr: self.inner.remote_addr, + }; + + self.inner.logger.on_connect(self.inner.remote_addr, &request, TransportProtocol::Http); + + Box::pin(http::handle_request(request, data).map(Ok)) + } else { + Box::pin(async { http::response::denied() }.map(Ok)) + } + } +} + +/// This is a glorified select listening for new messages, while also checking the `stop_receiver` signal. 
+struct Monitored<'a, F> { + future: F, + stop_monitor: &'a StopHandle, +} + +impl<'a, F> Monitored<'a, F> { + fn new(future: F, stop_monitor: &'a StopHandle) -> Self { + Monitored { future, stop_monitor } + } +} + +enum MonitoredError { + Shutdown, + Selector(E), +} + +struct Incoming(TcpListener); + +impl<'a> Future for Monitored<'a, Incoming> { + type Output = Result<(TcpStream, SocketAddr), MonitoredError>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { + let this = Pin::into_inner(self); + + if this.stop_monitor.shutdown_requested() { + return Poll::Ready(Err(MonitoredError::Shutdown)); + } + + this.future.0.poll_accept(cx).map_err(MonitoredError::Selector) + } +} + +impl<'a, 'f, F, T, E> Future for Monitored<'a, Pin<&'f mut F>> +where + F: Future>, +{ + type Output = Result>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { + let this = Pin::into_inner(self); + + if this.stop_monitor.shutdown_requested() { + return Poll::Ready(Err(MonitoredError::Shutdown)); + } + + this.future.poll_unpin(cx).map_err(MonitoredError::Selector) + } +} + +struct ProcessConnection { + /// Remote server address. + remote_addr: SocketAddr, + /// Registered server methods. + methods: Methods, + /// Access control. + allow_hosts: AllowHosts, + /// Tracker for currently used resources on the server. + resources: Resources, + /// Max request body size. + max_request_body_size: u32, + /// Max response body size. + max_response_body_size: u32, + /// Max length for logging for request and response + /// + /// Logs bigger than this limit will be truncated. + max_log_length: u32, + /// Whether batch requests are supported by this server or not. + batch_requests_supported: bool, + /// Subscription ID provider. + id_provider: Arc, + /// Ping interval + ping_interval: Duration, + /// Stop handle. + stop_handle: StopHandle, + /// Max subscriptions per connection. + max_subscriptions_per_connection: u32, + /// Max connections, + max_connections: u32, + /// Connection ID + conn_id: u32, + /// Logger. + logger: L, + /// Allow JSON-RPC HTTP requests. + enable_http: bool, + /// Allow JSON-RPC WS request and WS upgrade requests. + enable_ws: bool, +} + +#[instrument(name = "connection", skip_all, fields(remote_addr = %cfg.remote_addr, conn_id = %cfg.conn_id), level = "INFO")] +fn process_connection<'a, L: Logger, B, U>( + service_builder: &tower::ServiceBuilder, + connection_guard: &ConnectionGuard, + cfg: ProcessConnection, + socket: TcpStream, + connections: &mut FutureDriver>, +) where + B: Layer> + Send + 'static, + >>::Service: Send + + Service< + hyper::Request, + Response = hyper::Response, + Error = Box<(dyn StdError + Send + Sync + 'static)>, + >, + <>>::Service as Service>>::Future: Send, + U: HttpBody + Send + 'static, + ::Error: Send + Sync + StdError, + ::Data: Send, +{ + if let Err(e) = socket.set_nodelay(true) { + tracing::warn!("Could not set NODELAY on socket: {:?}", e); + return; + } + + let conn = match connection_guard.try_acquire() { + Some(conn) => conn, + None => { + tracing::warn!("Too many connections. 
Please try again later."); + connections.add(http::reject_connection(socket).in_current_span().boxed()); + return; + } + }; + + let max_conns = cfg.max_connections as usize; + let curr_conns = max_conns - connection_guard.available_connections(); + tracing::info!("Accepting new connection {}/{}", curr_conns, max_conns); + + let tower_service = TowerService { + inner: ServiceData { + remote_addr: cfg.remote_addr, + methods: cfg.methods, + allow_hosts: cfg.allow_hosts, + resources: cfg.resources, + max_request_body_size: cfg.max_request_body_size, + max_response_body_size: cfg.max_response_body_size, + max_log_length: cfg.max_log_length, + batch_requests_supported: cfg.batch_requests_supported, + id_provider: cfg.id_provider, + ping_interval: cfg.ping_interval, + stop_handle: cfg.stop_handle.clone(), + max_subscriptions_per_connection: cfg.max_subscriptions_per_connection, + conn_id: cfg.conn_id, + logger: cfg.logger, + conn: Arc::new(conn), + enable_http: cfg.enable_http, + enable_ws: cfg.enable_ws, + }, + }; + + let service = service_builder.service(tower_service); + + connections + .add(Box::pin(try_accept_connection(socket, service, cfg.stop_handle).in_current_span())); +} + +// Attempts to create a HTTP connection from a socket. +async fn try_accept_connection(socket: TcpStream, service: S, mut stop_handle: StopHandle) +where + S: Service, Response = hyper::Response> + Send + 'static, + S::Error: Into>, + S::Future: Send, + Bd: HttpBody + Send + 'static, + ::Error: Send + Sync + StdError, + ::Data: Send, +{ + let conn = hyper::server::conn::Http::new().serve_connection(socket, service).with_upgrades(); + + tokio::pin!(conn); + + tokio::select! { + res = &mut conn => { + if let Err(e) = res { + tracing::warn!("HTTP serve connection failed {:?}", e); + } + } + _ = stop_handle.shutdown() => { + conn.graceful_shutdown(); + } + } +} diff --git a/crates/katana/rpc/rpc/src/transport/http.rs b/crates/katana/rpc/rpc/src/transport/http.rs new file mode 100644 index 0000000000..ccacb8a463 --- /dev/null +++ b/crates/katana/rpc/rpc/src/transport/http.rs @@ -0,0 +1,502 @@ +use std::convert::Infallible; +use std::net::SocketAddr; +use std::sync::Arc; + +use crate::logger::{self, Logger, TransportProtocol}; + +use futures_util::future::Either; +use futures_util::stream::{FuturesOrdered, StreamExt}; +use http::Method; +use jsonrpsee_core::error::GenericTransportError; +use jsonrpsee_core::http_helpers::read_body; +use jsonrpsee_core::server::helpers::{ + prepare_error, BatchResponse, BatchResponseBuilder, MethodResponse, +}; +use jsonrpsee_core::server::rpc_module::MethodKind; +use jsonrpsee_core::server::{resource_limiting::Resources, rpc_module::Methods}; +use jsonrpsee_core::tracing::{rx_log_from_json, tx_log_from_str}; +use jsonrpsee_core::JsonRawValue; +use jsonrpsee_types::error::{ErrorCode, BATCHES_NOT_SUPPORTED_CODE, BATCHES_NOT_SUPPORTED_MSG}; +use jsonrpsee_types::{ErrorObject, Id, InvalidRequest, Notification, Params, Request}; +use tokio::sync::OwnedSemaphorePermit; +use tracing::instrument; + +type Notif<'a> = Notification<'a, Option<&'a JsonRawValue>>; + +/// Checks that content type of received request is valid for JSON-RPC. +pub(crate) fn content_type_is_json(request: &hyper::Request) -> bool { + is_json(request.headers().get(http::header::CONTENT_TYPE)) +} + +/// Returns true if the `content_type` header indicates a valid JSON message. 
+pub(crate) fn is_json(content_type: Option<&hyper::header::HeaderValue>) -> bool { + content_type.and_then(|val| val.to_str().ok()).map_or(false, |content| { + content.eq_ignore_ascii_case("application/json") + || content.eq_ignore_ascii_case("application/json; charset=utf-8") + || content.eq_ignore_ascii_case("application/json;charset=utf-8") + }) +} + +pub(crate) async fn reject_connection(socket: tokio::net::TcpStream) { + async fn reject( + _req: hyper::Request, + ) -> Result, Infallible> { + Ok(response::too_many_requests()) + } + + if let Err(e) = hyper::server::conn::Http::new() + .serve_connection(socket, hyper::service::service_fn(reject)) + .await + { + tracing::warn!("Error when trying to deny connection: {:?}", e); + } +} + +#[derive(Debug)] +pub(crate) struct ProcessValidatedRequest<'a, L: Logger> { + pub(crate) request: hyper::Request, + pub(crate) logger: &'a L, + pub(crate) methods: Methods, + pub(crate) resources: Resources, + pub(crate) max_request_body_size: u32, + pub(crate) max_response_body_size: u32, + pub(crate) max_log_length: u32, + pub(crate) batch_requests_supported: bool, + pub(crate) request_start: L::Instant, +} + +/// Process a verified request, it implies a POST request with content type JSON. +pub(crate) async fn process_validated_request( + input: ProcessValidatedRequest<'_, L>, +) -> hyper::Response { + let ProcessValidatedRequest { + request, + logger, + methods, + resources, + max_request_body_size, + max_response_body_size, + max_log_length, + batch_requests_supported, + request_start, + } = input; + + let (parts, body) = request.into_parts(); + + let (body, is_single) = match read_body(&parts.headers, body, max_request_body_size).await { + Ok(r) => r, + Err(GenericTransportError::TooLarge) => return response::too_large(max_request_body_size), + Err(GenericTransportError::Malformed) => return response::malformed(), + Err(GenericTransportError::Inner(e)) => { + tracing::error!("Internal error reading request body: {}", e); + return response::internal_error(); + } + }; + + // Single request or notification + if is_single { + let call = CallData { + conn_id: 0, + logger, + methods: &methods, + max_response_body_size, + max_log_length, + resources: &resources, + request_start, + }; + let response = process_single_request(body, call).await; + logger.on_response(&response.result, request_start, TransportProtocol::Http); + response::ok_response(response.result) + } + // Batch of requests or notifications + else if !batch_requests_supported { + let err = MethodResponse::error( + Id::Null, + ErrorObject::borrowed(BATCHES_NOT_SUPPORTED_CODE, &BATCHES_NOT_SUPPORTED_MSG, None), + ); + logger.on_response(&err.result, request_start, TransportProtocol::Http); + response::ok_response(err.result) + } + // Batch of requests or notifications + else { + let response = process_batch_request(Batch { + data: body, + call: CallData { + conn_id: 0, + logger, + methods: &methods, + max_response_body_size, + max_log_length, + resources: &resources, + request_start, + }, + }) + .await; + logger.on_response(&response.result, request_start, TransportProtocol::Http); + response::ok_response(response.result) + } +} + +#[derive(Debug, Clone)] +pub(crate) struct Batch<'a, L: Logger> { + data: Vec, + call: CallData<'a, L>, +} + +#[derive(Debug, Clone)] +pub(crate) struct CallData<'a, L: Logger> { + conn_id: usize, + logger: &'a L, + methods: &'a Methods, + max_response_body_size: u32, + max_log_length: u32, + resources: &'a Resources, + request_start: L::Instant, +} + +// Batch 
responses must be sent back as a single message so we read the results from each +// request in the batch and read the results off of a new channel, `rx_batch`, and then send the +// complete batch response back to the client over `tx`. +#[instrument(name = "batch", skip(b), level = "TRACE")] +pub(crate) async fn process_batch_request(b: Batch<'_, L>) -> BatchResponse +where + L: Logger, +{ + let Batch { data, call } = b; + + if let Ok(batch) = serde_json::from_slice::>(&data) { + let mut got_notif = false; + let mut batch_response = + BatchResponseBuilder::new_with_limit(call.max_response_body_size as usize); + + let mut pending_calls: FuturesOrdered<_> = batch + .into_iter() + .filter_map(|v| { + if let Ok(req) = serde_json::from_str::(v.get()) { + Some(Either::Right(execute_call(req, call.clone()))) + } else if let Ok(_notif) = + serde_json::from_str::>(v.get()) + { + // notifications should not be answered. + got_notif = true; + None + } else { + // valid JSON but could be not parsable as `InvalidRequest` + let id = match serde_json::from_str::(v.get()) { + Ok(err) => err.id, + Err(_) => Id::Null, + }; + + Some(Either::Left(async { + MethodResponse::error(id, ErrorObject::from(ErrorCode::InvalidRequest)) + })) + } + }) + .collect(); + + while let Some(response) = pending_calls.next().await { + if let Err(too_large) = batch_response.append(&response) { + return too_large; + } + } + + if got_notif && batch_response.is_empty() { + BatchResponse { result: String::new(), success: true } + } else { + batch_response.finish() + } + } else { + BatchResponse::error(Id::Null, ErrorObject::from(ErrorCode::ParseError)) + } +} + +pub(crate) async fn process_single_request( + data: Vec, + call: CallData<'_, L>, +) -> MethodResponse { + if let Ok(req) = serde_json::from_slice::(&data) { + execute_call_with_tracing(req, call).await + } else if let Ok(notif) = serde_json::from_slice::(&data) { + execute_notification(notif, call.max_log_length) + } else { + let (id, code) = prepare_error(&data); + MethodResponse::error(id, ErrorObject::from(code)) + } +} + +#[instrument(name = "method_call", fields(method = req.method.as_ref()), skip(call, req), level = "TRACE")] +pub(crate) async fn execute_call_with_tracing<'a, L: Logger>( + req: Request<'a>, + call: CallData<'_, L>, +) -> MethodResponse { + execute_call(req, call).await +} + +pub(crate) async fn execute_call( + req: Request<'_>, + call: CallData<'_, L>, +) -> MethodResponse { + let CallData { + resources, + methods, + logger, + max_response_body_size, + max_log_length, + conn_id, + request_start, + } = call; + + rx_log_from_json(&req, call.max_log_length); + + let params = Params::new(req.params.map(|params| params.get())); + let name = &req.method; + let id = req.id; + + let response = match methods.method_with_name(name) { + None => { + logger.on_call( + name, + params.clone(), + logger::MethodKind::Unknown, + TransportProtocol::Http, + ); + MethodResponse::error(id, ErrorObject::from(ErrorCode::MethodNotFound)) + } + Some((name, method)) => match &method.inner() { + MethodKind::Sync(callback) => { + logger.on_call( + name, + params.clone(), + logger::MethodKind::MethodCall, + TransportProtocol::Http, + ); + + match method.claim(name, resources) { + Ok(guard) => { + let r = (callback)(id, params, max_response_body_size as usize); + drop(guard); + r + } + Err(err) => { + tracing::error!( + "[Methods::execute_with_resources] failed to lock resources: {}", + err + ); + MethodResponse::error(id, ErrorObject::from(ErrorCode::ServerIsBusy)) + } + } + } + 
MethodKind::Async(callback) => { + logger.on_call( + name, + params.clone(), + logger::MethodKind::MethodCall, + TransportProtocol::Http, + ); + match method.claim(name, resources) { + Ok(guard) => { + let id = id.into_owned(); + let params = params.into_owned(); + + (callback)( + id, + params, + conn_id, + max_response_body_size as usize, + Some(guard), + ) + .await + } + Err(err) => { + tracing::error!( + "[Methods::execute_with_resources] failed to lock resources: {}", + err + ); + MethodResponse::error(id, ErrorObject::from(ErrorCode::ServerIsBusy)) + } + } + } + MethodKind::Subscription(_) | MethodKind::Unsubscription(_) => { + logger.on_call( + name, + params.clone(), + logger::MethodKind::Unknown, + TransportProtocol::Http, + ); + tracing::error!("Subscriptions not supported on HTTP"); + MethodResponse::error(id, ErrorObject::from(ErrorCode::InternalError)) + } + }, + }; + + tx_log_from_str(&response.result, max_log_length); + logger.on_result(name, response.success, request_start, TransportProtocol::Http); + response +} + +#[instrument(name = "notification", fields(method = notif.method.as_ref()), skip(notif, max_log_length), level = "TRACE")] +fn execute_notification(notif: Notif, max_log_length: u32) -> MethodResponse { + rx_log_from_json(¬if, max_log_length); + let response = MethodResponse { result: String::new(), success: true }; + tx_log_from_str(&response.result, max_log_length); + response +} + +pub(crate) struct HandleRequest { + pub(crate) methods: Methods, + pub(crate) resources: Resources, + pub(crate) max_request_body_size: u32, + pub(crate) max_response_body_size: u32, + pub(crate) max_log_length: u32, + pub(crate) batch_requests_supported: bool, + pub(crate) logger: L, + pub(crate) conn: Arc, + pub(crate) remote_addr: SocketAddr, +} + +pub(crate) async fn handle_request( + request: hyper::Request, + input: HandleRequest, +) -> hyper::Response { + let HandleRequest { + methods, + resources, + max_request_body_size, + max_response_body_size, + max_log_length, + batch_requests_supported, + logger, + conn, + remote_addr, + } = input; + + let request_start = logger.on_request(TransportProtocol::Http); + + // Only the `POST` method is allowed. + let res = match *request.method() { + Method::POST if content_type_is_json(&request) => { + process_validated_request(ProcessValidatedRequest { + request, + methods, + resources, + max_request_body_size, + max_response_body_size, + max_log_length, + batch_requests_supported, + logger: &logger, + request_start, + }) + .await + } + // Error scenarios: + Method::POST => response::unsupported_content_type(), + _ => response::method_not_allowed(), + }; + + drop(conn); + logger.on_disconnect(remote_addr, TransportProtocol::Http); + + res +} + +pub(crate) mod response { + use jsonrpsee_types::error::reject_too_big_request; + use jsonrpsee_types::error::{ErrorCode, ErrorResponse}; + use jsonrpsee_types::Id; + + const JSON: &str = "application/json; charset=utf-8"; + const TEXT: &str = "text/plain"; + + /// Create a response for json internal error. + pub(crate) fn internal_error() -> hyper::Response { + let error = serde_json::to_string(&ErrorResponse::borrowed( + ErrorCode::InternalError.into(), + Id::Null, + )) + .expect("built from known-good data; qed"); + + from_template(hyper::StatusCode::INTERNAL_SERVER_ERROR, error, JSON) + } + + /// Create a text/plain response for not allowed hosts. 
+ pub(crate) fn host_not_allowed() -> hyper::Response { + from_template( + hyper::StatusCode::FORBIDDEN, + "Provided Host header is not whitelisted.\n".to_owned(), + TEXT, + ) + } + + /// Create a text/plain response for disallowed method used. + pub(crate) fn method_not_allowed() -> hyper::Response { + from_template( + hyper::StatusCode::METHOD_NOT_ALLOWED, + "Used HTTP Method is not allowed. POST or OPTIONS is required\n".to_owned(), + TEXT, + ) + } + + /// Create a json response for oversized requests (413) + pub(crate) fn too_large(limit: u32) -> hyper::Response { + let error = serde_json::to_string(&ErrorResponse::borrowed( + reject_too_big_request(limit), + Id::Null, + )) + .expect("built from known-good data; qed"); + + from_template(hyper::StatusCode::PAYLOAD_TOO_LARGE, error, JSON) + } + + /// Create a json response for empty or malformed requests (400) + pub(crate) fn malformed() -> hyper::Response { + let error = + serde_json::to_string(&ErrorResponse::borrowed(ErrorCode::ParseError.into(), Id::Null)) + .expect("built from known-good data; qed"); + + from_template(hyper::StatusCode::BAD_REQUEST, error, JSON) + } + + /// Create a response body. + fn from_template>( + status: hyper::StatusCode, + body: S, + content_type: &'static str, + ) -> hyper::Response { + hyper::Response::builder() + .status(status) + .header("content-type", hyper::header::HeaderValue::from_static(content_type)) + .body(body.into()) + // Parsing `StatusCode` and `HeaderValue` is infalliable but + // parsing body content is not. + .expect("Unable to parse response body for type conversion") + } + + /// Create a valid JSON response. + pub(crate) fn ok_response(body: String) -> hyper::Response { + from_template(hyper::StatusCode::OK, body, JSON) + } + + /// Create a response for unsupported content type. + pub(crate) fn unsupported_content_type() -> hyper::Response { + from_template( + hyper::StatusCode::UNSUPPORTED_MEDIA_TYPE, + "Supplied content type is not allowed. Content-Type: application/json is required\n" + .to_owned(), + TEXT, + ) + } + + /// Create a response for when the server is busy and can't accept more requests. + pub(crate) fn too_many_requests() -> hyper::Response { + from_template( + hyper::StatusCode::TOO_MANY_REQUESTS, + "Too many connections. Please try again later.".to_owned(), + TEXT, + ) + } + + /// Create a response for when the server denied the request. 
+ pub(crate) fn denied() -> hyper::Response { + from_template(hyper::StatusCode::FORBIDDEN, "".to_owned(), TEXT) + } +} diff --git a/crates/katana/rpc/rpc/src/transport/mod.rs b/crates/katana/rpc/rpc/src/transport/mod.rs new file mode 100644 index 0000000000..8b293e269f --- /dev/null +++ b/crates/katana/rpc/rpc/src/transport/mod.rs @@ -0,0 +1,2 @@ +pub(crate) mod http; +pub(crate) mod ws; diff --git a/crates/katana/rpc/rpc/src/transport/ws.rs b/crates/katana/rpc/rpc/src/transport/ws.rs new file mode 100644 index 0000000000..a6552a8d6c --- /dev/null +++ b/crates/katana/rpc/rpc/src/transport/ws.rs @@ -0,0 +1,613 @@ +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::time::Duration; + +use crate::future::{FutureDriver, StopHandle}; +use crate::logger::{self, Logger, TransportProtocol}; +use crate::server::{MethodResult, ServiceData}; + +use futures_channel::mpsc; +use futures_util::future::{self, Either}; +use futures_util::io::{BufReader, BufWriter}; +use futures_util::stream::FuturesOrdered; +use futures_util::{Future, FutureExt, StreamExt}; +use hyper::upgrade::Upgraded; +use jsonrpsee_core::server::helpers::{ + prepare_error, BatchResponse, BatchResponseBuilder, BoundedSubscriptions, MethodResponse, + MethodSink, +}; +use jsonrpsee_core::server::resource_limiting::Resources; +use jsonrpsee_core::server::rpc_module::{ConnState, MethodKind, Methods}; +use jsonrpsee_core::tracing::{rx_log_from_json, tx_log_from_str}; +use jsonrpsee_core::traits::IdProvider; +use jsonrpsee_core::{Error, JsonRawValue}; +use jsonrpsee_types::error::{ + reject_too_big_request, reject_too_many_subscriptions, ErrorCode, BATCHES_NOT_SUPPORTED_CODE, + BATCHES_NOT_SUPPORTED_MSG, +}; +use jsonrpsee_types::{ErrorObject, Id, InvalidRequest, Notification, Params, Request}; +use soketto::connection::Error as SokettoError; +use soketto::data::ByteSlice125; +use tokio_stream::wrappers::IntervalStream; +use tokio_util::compat::Compat; +use tracing::instrument; + +pub(crate) type Sender = soketto::Sender>>>; +pub(crate) type Receiver = soketto::Receiver>>>; + +pub(crate) async fn send_message(sender: &mut Sender, response: String) -> Result<(), Error> { + sender.send_text_owned(response).await?; + sender.flush().await.map_err(Into::into) +} + +pub(crate) async fn send_ping(sender: &mut Sender) -> Result<(), Error> { + tracing::debug!("Send ping"); + // Submit empty slice as "optional" parameter. + let slice: &[u8] = &[]; + // Byte slice fails if the provided slice is larger than 125 bytes. + let byte_slice = + ByteSlice125::try_from(slice).expect("Empty slice should fit into ByteSlice125"); + sender.send_ping(byte_slice).await?; + sender.flush().await.map_err(Into::into) +} + +#[derive(Debug, Clone)] +pub(crate) struct Batch<'a, L: Logger> { + pub(crate) data: Vec, + pub(crate) call: CallData<'a, L>, +} + +#[derive(Debug, Clone)] +pub(crate) struct CallData<'a, L: Logger> { + pub(crate) conn_id: usize, + pub(crate) bounded_subscriptions: BoundedSubscriptions, + pub(crate) id_provider: &'a dyn IdProvider, + pub(crate) methods: &'a Methods, + pub(crate) max_response_body_size: u32, + pub(crate) max_log_length: u32, + pub(crate) resources: &'a Resources, + pub(crate) sink: &'a MethodSink, + pub(crate) logger: &'a L, + pub(crate) request_start: L::Instant, +} + +/// This is a glorified select listening for new messages, while also checking the `stop_receiver` signal. 
+struct Monitored<'a, F> { + future: F, + stop_monitor: &'a StopHandle, +} + +impl<'a, F> Monitored<'a, F> { + fn new(future: F, stop_monitor: &'a StopHandle) -> Self { + Monitored { future, stop_monitor } + } +} + +enum MonitoredError { + Shutdown, + Selector(E), +} + +impl<'a, 'f, F, T, E> Future for Monitored<'a, Pin<&'f mut F>> +where + F: Future>, +{ + type Output = Result>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { + let this = Pin::into_inner(self); + + if this.stop_monitor.shutdown_requested() { + return Poll::Ready(Err(MonitoredError::Shutdown)); + } + + this.future.poll_unpin(cx).map_err(MonitoredError::Selector) + } +} + +// Batch responses must be sent back as a single message so we read the results from each +// request in the batch and read the results off of a new channel, `rx_batch`, and then send the +// complete batch response back to the client over `tx`. +#[instrument(name = "batch", skip(b), level = "TRACE")] +pub(crate) async fn process_batch_request(b: Batch<'_, L>) -> Option { + let Batch { data, call } = b; + + if let Ok(batch) = serde_json::from_slice::>(&data) { + let mut got_notif = false; + let mut batch_response = + BatchResponseBuilder::new_with_limit(call.max_response_body_size as usize); + + let mut pending_calls: FuturesOrdered<_> = batch + .into_iter() + .filter_map(|v| { + if let Ok(req) = serde_json::from_str::(v.get()) { + Some(Either::Right(async { + execute_call(req, call.clone()).await.into_inner() + })) + } else if let Ok(_notif) = + serde_json::from_str::>(v.get()) + { + // notifications should not be answered. + got_notif = true; + None + } else { + // valid JSON but could be not parsable as `InvalidRequest` + let id = match serde_json::from_str::(v.get()) { + Ok(err) => err.id, + Err(_) => Id::Null, + }; + + Some(Either::Left(async { + MethodResponse::error(id, ErrorObject::from(ErrorCode::InvalidRequest)) + })) + } + }) + .collect(); + + while let Some(response) = pending_calls.next().await { + if let Err(too_large) = batch_response.append(&response) { + return Some(too_large); + } + } + + if got_notif && batch_response.is_empty() { + None + } else { + Some(batch_response.finish()) + } + } else { + Some(BatchResponse::error(Id::Null, ErrorObject::from(ErrorCode::ParseError))) + } +} + +pub(crate) async fn process_single_request( + data: Vec, + call: CallData<'_, L>, +) -> MethodResult { + if let Ok(req) = serde_json::from_slice::(&data) { + execute_call_with_tracing(req, call).await + } else { + let (id, code) = prepare_error(&data); + MethodResult::SendAndLogger(MethodResponse::error(id, ErrorObject::from(code))) + } +} + +#[instrument(name = "method_call", fields(method = req.method.as_ref()), skip(call, req), level = "TRACE")] +pub(crate) async fn execute_call_with_tracing<'a, L: Logger>( + req: Request<'a>, + call: CallData<'_, L>, +) -> MethodResult { + execute_call(req, call).await +} + +/// Execute a call which returns result of the call with a additional sink +/// to fire a signal once the subscription call has been answered. +/// +/// Returns `(MethodResponse, None)` on every call that isn't a subscription +/// Otherwise `(MethodResponse, Some(PendingSubscriptionCallTx)`. 
+pub(crate) async fn execute_call<'a, L: Logger>( + req: Request<'a>, + call: CallData<'_, L>, +) -> MethodResult { + let CallData { + resources, + methods, + max_response_body_size, + max_log_length, + conn_id, + bounded_subscriptions, + id_provider, + sink, + logger, + request_start, + } = call; + + rx_log_from_json(&req, call.max_log_length); + + let params = Params::new(req.params.map(|params| params.get())); + let name = &req.method; + let id = req.id; + + let response = match methods.method_with_name(name) { + None => { + logger.on_call( + name, + params.clone(), + logger::MethodKind::Unknown, + TransportProtocol::WebSocket, + ); + let response = MethodResponse::error(id, ErrorObject::from(ErrorCode::MethodNotFound)); + MethodResult::SendAndLogger(response) + } + Some((name, method)) => match &method.inner() { + MethodKind::Sync(callback) => { + logger.on_call( + name, + params.clone(), + logger::MethodKind::MethodCall, + TransportProtocol::WebSocket, + ); + match method.claim(name, resources) { + Ok(guard) => { + let r = (callback)(id, params, max_response_body_size as usize); + drop(guard); + MethodResult::SendAndLogger(r) + } + Err(err) => { + tracing::error!( + "[Methods::execute_with_resources] failed to lock resources: {}", + err + ); + let response = + MethodResponse::error(id, ErrorObject::from(ErrorCode::ServerIsBusy)); + MethodResult::SendAndLogger(response) + } + } + } + MethodKind::Async(callback) => { + logger.on_call( + name, + params.clone(), + logger::MethodKind::MethodCall, + TransportProtocol::WebSocket, + ); + match method.claim(name, resources) { + Ok(guard) => { + let id = id.into_owned(); + let params = params.into_owned(); + + let response = (callback)( + id, + params, + conn_id, + max_response_body_size as usize, + Some(guard), + ) + .await; + MethodResult::SendAndLogger(response) + } + Err(err) => { + tracing::error!( + "[Methods::execute_with_resources] failed to lock resources: {}", + err + ); + let response = + MethodResponse::error(id, ErrorObject::from(ErrorCode::ServerIsBusy)); + MethodResult::SendAndLogger(response) + } + } + } + MethodKind::Subscription(callback) => { + logger.on_call( + name, + params.clone(), + logger::MethodKind::Subscription, + TransportProtocol::WebSocket, + ); + match method.claim(name, resources) { + Ok(guard) => { + if let Some(cn) = bounded_subscriptions.acquire() { + let conn_state = ConnState { conn_id, close_notify: cn, id_provider }; + let response = + callback(id.clone(), params, sink.clone(), conn_state, Some(guard)) + .await; + MethodResult::JustLogger(response) + } else { + let response = MethodResponse::error( + id, + reject_too_many_subscriptions(bounded_subscriptions.max()), + ); + MethodResult::SendAndLogger(response) + } + } + Err(err) => { + tracing::error!( + "[Methods::execute_with_resources] failed to lock resources: {}", + err + ); + let response = + MethodResponse::error(id, ErrorObject::from(ErrorCode::ServerIsBusy)); + MethodResult::SendAndLogger(response) + } + } + } + MethodKind::Unsubscription(callback) => { + logger.on_call( + name, + params.clone(), + logger::MethodKind::Unsubscription, + TransportProtocol::WebSocket, + ); + + // Don't adhere to any resource or subscription limits; always let unsubscribing happen! 
+ let result = callback(id, params, conn_id, max_response_body_size as usize); + MethodResult::SendAndLogger(result) + } + }, + }; + + let r = response.as_inner(); + + tx_log_from_str(&r.result, max_log_length); + logger.on_result(name, r.success, request_start, TransportProtocol::WebSocket); + response +} + +pub(crate) async fn background_task( + sender: Sender, + mut receiver: Receiver, + svc: ServiceData, +) -> Result<(), Error> { + let ServiceData { + methods, + resources, + max_request_body_size, + max_response_body_size, + max_log_length, + batch_requests_supported, + stop_handle, + id_provider, + ping_interval, + max_subscriptions_per_connection, + conn_id, + logger, + remote_addr, + conn, + .. + } = svc; + + let (tx, rx) = mpsc::unbounded::(); + let bounded_subscriptions = BoundedSubscriptions::new(max_subscriptions_per_connection); + let sink = MethodSink::new_with_limit(tx, max_response_body_size, max_log_length); + + // Spawn another task that sends out the responses on the Websocket. + tokio::spawn(send_task(rx, sender, stop_handle.clone(), ping_interval)); + + // Buffer for incoming data. + let mut data = Vec::with_capacity(100); + let mut method_executors = FutureDriver::default(); + let logger = &logger; + + let result = loop { + data.clear(); + + { + // Need the extra scope to drop this pinned future and reclaim access to `data` + let receive = async { + // Identical loop to `soketto::receive_data` with debug logs for `Pong` frames. + loop { + match receiver.receive(&mut data).await? { + soketto::Incoming::Data(d) => break Ok(d), + soketto::Incoming::Pong(_) => tracing::debug!("Received pong"), + soketto::Incoming::Closed(_) => { + // The closing reason is already logged by `soketto` trace log level. + // Return the `Closed` error to avoid logging unnecessary warnings on clean shutdown. + break Err(SokettoError::Closed); + } + } + } + }; + + tokio::pin!(receive); + + if let Err(err) = + method_executors.select_with(Monitored::new(receive, &stop_handle)).await + { + match err { + MonitoredError::Selector(SokettoError::Closed) => { + tracing::debug!( + "WS transport: remote peer terminated the connection: {}", + conn_id + ); + break Ok(()); + } + MonitoredError::Selector(SokettoError::MessageTooLarge { + current, + maximum, + }) => { + tracing::warn!( + "WS transport error: request length: {} exceeded max limit: {} bytes", + current, + maximum + ); + sink.send_error(Id::Null, reject_too_big_request(max_request_body_size)); + continue; + } + + // These errors can not be gracefully handled, so just log them and terminate the connection. 
+ MonitoredError::Selector(err) => { + tracing::error!( + "WS transport error: {}; terminate connection: {}", + err, + conn_id + ); + break Err(err.into()); + } + MonitoredError::Shutdown => { + break Ok(()); + } + }; + }; + }; + + let request_start = logger.on_request(TransportProtocol::WebSocket); + + let first_non_whitespace = data.iter().find(|byte| !byte.is_ascii_whitespace()); + match first_non_whitespace { + Some(b'{') => { + let data = std::mem::take(&mut data); + let sink = sink.clone(); + let resources = &resources; + let methods = &methods; + let bounded_subscriptions = bounded_subscriptions.clone(); + let id_provider = &*id_provider; + + let fut = async move { + let call = CallData { + conn_id: conn_id as usize, + resources, + max_response_body_size, + max_log_length, + methods, + bounded_subscriptions, + sink: &sink, + id_provider, + logger, + request_start, + }; + + match process_single_request(data, call).await { + MethodResult::JustLogger(r) => { + logger.on_response( + &r.result, + request_start, + TransportProtocol::WebSocket, + ); + } + MethodResult::SendAndLogger(r) => { + logger.on_response( + &r.result, + request_start, + TransportProtocol::WebSocket, + ); + let _ = sink.send_raw(r.result); + } + }; + } + .boxed(); + + method_executors.add(fut); + } + Some(b'[') if !batch_requests_supported => { + let response = MethodResponse::error( + Id::Null, + ErrorObject::borrowed( + BATCHES_NOT_SUPPORTED_CODE, + &BATCHES_NOT_SUPPORTED_MSG, + None, + ), + ); + logger.on_response(&response.result, request_start, TransportProtocol::WebSocket); + let _ = sink.send_raw(response.result); + } + Some(b'[') => { + // Make sure the following variables are not moved into async closure below. + let resources = &resources; + let methods = &methods; + let bounded_subscriptions = bounded_subscriptions.clone(); + let sink = sink.clone(); + let id_provider = id_provider.clone(); + let data = std::mem::take(&mut data); + + let fut = async move { + let response = process_batch_request(Batch { + data, + call: CallData { + conn_id: conn_id as usize, + resources, + max_response_body_size, + max_log_length, + methods, + bounded_subscriptions, + sink: &sink, + id_provider: &*id_provider, + logger, + request_start, + }, + }) + .await; + + if let Some(response) = response { + tx_log_from_str(&response.result, max_log_length); + logger.on_response( + &response.result, + request_start, + TransportProtocol::WebSocket, + ); + let _ = sink.send_raw(response.result); + } + }; + + method_executors.add(Box::pin(fut)); + } + _ => { + sink.send_error(Id::Null, ErrorCode::ParseError.into()); + } + } + }; + + logger.on_disconnect(remote_addr, TransportProtocol::WebSocket); + + // Drive all running methods to completion. + // **NOTE** Do not return early in this function. This `await` needs to run to guarantee + // proper drop behaviour. + method_executors.await; + + // Notify all listeners and close down associated tasks. + sink.close(); + bounded_subscriptions.close(); + + drop(conn); + + result +} + +/// A task that waits for new messages via the `rx channel` and sends them out on the `WebSocket`. +async fn send_task( + mut rx: mpsc::UnboundedReceiver, + mut ws_sender: Sender, + mut stop_handle: StopHandle, + ping_interval: Duration, +) { + // Received messages from the WebSocket. + let mut rx_item = rx.next(); + + // Interval to send out continuously `pings`. 
+ let ping_interval = IntervalStream::new(tokio::time::interval(ping_interval)); + let stopped = stop_handle.shutdown(); + + tokio::pin!(ping_interval, stopped); + + let next_ping = ping_interval.next(); + let mut futs = future::select(next_ping, stopped); + + loop { + // Ensure select is cancel-safe by fetching and storing the `rx_item` that did not finish yet. + // Note: Although, this is cancel-safe already, avoid using `select!` macro for future proofing. + match future::select(rx_item, futs).await { + // Received message. + Either::Left((Some(response), not_ready)) => { + // If websocket message send fail then terminate the connection. + if let Err(err) = send_message(&mut ws_sender, response).await { + tracing::error!("WS transport error: send failed: {}", err); + break; + } + rx_item = rx.next(); + futs = not_ready; + } + + // Nothing else to receive. + Either::Left((None, _)) => { + break; + } + + // Handle timer intervals. + Either::Right((Either::Left((_, stop)), next_rx)) => { + if let Err(err) = send_ping(&mut ws_sender).await { + tracing::error!("WS transport error: send ping failed: {}", err); + break; + } + rx_item = next_rx; + futs = future::select(ping_interval.next(), stop); + } + + // Server is closed + Either::Right((Either::Right((_, _)), _)) => { + break; + } + } + } + + // Terminate connection and send close message. + let _ = ws_sender.close().await; +} From 39477956474142a05bc8ae8f7d6fd35245f75b00 Mon Sep 17 00:00:00 2001 From: Fabricio Date: Tue, 22 Oct 2024 17:58:10 -0600 Subject: [PATCH 06/21] cargo fmt --- crates/katana/node/src/lib.rs | 4 +- crates/katana/rpc/rpc-api/src/dev.rs | 2 +- crates/katana/rpc/rpc/src/dev.rs | 2 +- crates/katana/rpc/rpc/src/future.rs | 6 +- crates/katana/rpc/rpc/src/lib.rs | 6 +- crates/katana/rpc/rpc/src/logger.rs | 12 +- .../katana/rpc/rpc/src/proxy_get_request.rs | 28 ++-- crates/katana/rpc/rpc/src/server.rs | 138 ++++++++++++------ crates/katana/rpc/rpc/src/transport/http.rs | 11 +- crates/katana/rpc/rpc/src/transport/ws.rs | 31 ++-- 10 files changed, 142 insertions(+), 98 deletions(-) diff --git a/crates/katana/node/src/lib.rs b/crates/katana/node/src/lib.rs index 62024e475a..dd6578b4d6 100644 --- a/crates/katana/node/src/lib.rs +++ b/crates/katana/node/src/lib.rs @@ -7,8 +7,6 @@ use std::time::Duration; use anyhow::Result; use dojo_metrics::{metrics_process, prometheus_exporter, Report}; use hyper::{Method, Uri}; -// use jsonrpsee::server::middleware::proxy_get_request::ProxyGetRequestLayer; -use katana_rpc::proxy_get_request::DevnetProxyLayer; use jsonrpsee::server::{AllowHosts, ServerBuilder, ServerHandle}; use jsonrpsee::RpcModule; use katana_core::backend::config::StarknetConfig; @@ -34,6 +32,8 @@ use katana_provider::providers::in_memory::InMemoryProvider; use katana_rpc::config::ServerConfig; use katana_rpc::dev::DevApi; use katana_rpc::metrics::RpcServerMetrics; +// use jsonrpsee::server::middleware::proxy_get_request::ProxyGetRequestLayer; +use katana_rpc::proxy_get_request::DevnetProxyLayer; use katana_rpc::saya::SayaApi; use katana_rpc::starknet::StarknetApi; use katana_rpc::torii::ToriiApi; diff --git a/crates/katana/rpc/rpc-api/src/dev.rs b/crates/katana/rpc/rpc-api/src/dev.rs index 61889c1007..ad9937f453 100644 --- a/crates/katana/rpc/rpc-api/src/dev.rs +++ b/crates/katana/rpc/rpc-api/src/dev.rs @@ -20,7 +20,7 @@ pub trait DevApi { #[method(name = "setStorageAt")] async fn set_storage_at(&self, contract_address: Felt, key: Felt, value: Felt) - -> RpcResult<()>; + -> RpcResult<()>; #[method(name = 
"predeployedAccounts")] async fn predeployed_accounts(&self) -> RpcResult>; diff --git a/crates/katana/rpc/rpc/src/dev.rs b/crates/katana/rpc/rpc/src/dev.rs index 82e31736ad..36e5678e1c 100644 --- a/crates/katana/rpc/rpc/src/dev.rs +++ b/crates/katana/rpc/rpc/src/dev.rs @@ -7,12 +7,12 @@ use katana_core::service::block_producer::{BlockProducer, BlockProducerMode, Pen use katana_executor::ExecutorFactory; use katana_primitives::genesis::constant::ERC20_NAME_STORAGE_SLOT; use katana_primitives::ContractAddress; -use starknet_crypto::Felt; use katana_provider::traits::state::StateFactoryProvider; use katana_rpc_api::dev::DevApiServer; use katana_rpc_types::account::Account; use katana_rpc_types::error::dev::DevApiError; use starknet::core::utils::get_storage_var_address; +use starknet_crypto::Felt; #[allow(missing_debug_implementations)] pub struct DevApi { diff --git a/crates/katana/rpc/rpc/src/future.rs b/crates/katana/rpc/rpc/src/future.rs index 15ec7e29d0..82c28f1ebc 100644 --- a/crates/katana/rpc/rpc/src/future.rs +++ b/crates/katana/rpc/rpc/src/future.rs @@ -113,11 +113,7 @@ where this.drive(cx); - if this.futures.is_empty() { - Poll::Ready(()) - } else { - Poll::Pending - } + if this.futures.is_empty() { Poll::Ready(()) } else { Poll::Pending } } } diff --git a/crates/katana/rpc/rpc/src/lib.rs b/crates/katana/rpc/rpc/src/lib.rs index 5c56b2d0a6..8ca3f8e1ac 100644 --- a/crates/katana/rpc/rpc/src/lib.rs +++ b/crates/katana/rpc/rpc/src/lib.rs @@ -6,13 +6,13 @@ pub mod config; pub mod dev; pub mod metrics; +pub mod proxy_get_request; pub mod saya; pub mod starknet; pub mod torii; -pub mod proxy_get_request; -mod utils; -mod transport; mod future; mod logger; mod server; +mod transport; +mod utils; diff --git a/crates/katana/rpc/rpc/src/logger.rs b/crates/katana/rpc/rpc/src/logger.rs index 53e5eb4477..72fcc48605 100644 --- a/crates/katana/rpc/rpc/src/logger.rs +++ b/crates/katana/rpc/rpc/src/logger.rs @@ -79,11 +79,12 @@ impl std::fmt::Display for TransportProtocol { } } -/// Defines a logger specifically for WebSocket connections with callbacks during the RPC request life-cycle. -/// The primary use case for this is to collect timings for a larger metrics collection solution. +/// Defines a logger specifically for WebSocket connections with callbacks during the RPC request +/// life-cycle. The primary use case for this is to collect timings for a larger metrics collection +/// solution. /// -/// See the [`ServerBuilder::set_logger`](../../jsonrpsee_server/struct.ServerBuilder.html#method.set_logger) -/// for examples. +/// See the [`ServerBuilder::set_logger`](../../jsonrpsee_server/struct.ServerBuilder.html#method. +/// set_logger) for examples. pub trait Logger: Send + Sync + Clone + 'static { /// Intended to carry timestamp of a request, for example `std::time::Instant`. How the trait /// measures time, if at all, is entirely up to the implementation. @@ -104,7 +105,8 @@ pub trait Logger: Send + Sync + Clone + 'static { transport: TransportProtocol, ); - /// Called on each JSON-RPC method completion, batch requests will trigger `on_result` multiple times. + /// Called on each JSON-RPC method completion, batch requests will trigger `on_result` multiple + /// times. fn on_result( &self, method_name: &str, diff --git a/crates/katana/rpc/rpc/src/proxy_get_request.rs b/crates/katana/rpc/rpc/src/proxy_get_request.rs index 9b55ed9e84..bdf66da9e1 100644 --- a/crates/katana/rpc/rpc/src/proxy_get_request.rs +++ b/crates/katana/rpc/rpc/src/proxy_get_request.rs @@ -1,21 +1,22 @@ //! 
Middleware that proxies requests at a specified URI to internal //! RPC method calls. -use crate::transport::http; -use hyper::body; +use std::error::Error; +use std::future::Future; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + use hyper::header::{ACCEPT, CONTENT_TYPE}; use hyper::http::HeaderValue; use hyper::{Body, Method, Request, Response, Uri}; use jsonrpsee_core::error::Error as RpcError; use jsonrpsee_core::JsonRawValue; use jsonrpsee_types::{Id, RequestSer}; -use std::error::Error; -use std::future::Future; -use std::pin::Pin; -use std::sync::Arc; -use std::task::{Context, Poll}; use tower::{Layer, Service}; +use crate::transport::http; + /// Layer that applies [`DevnetProxy`] which proxies the `GET /path` requests to /// specific RPC method calls and that strips the response. /// @@ -33,9 +34,7 @@ impl DevnetProxyLayer { pub fn new(path: impl Into, method: impl Into) -> Result { let path = path.into(); if !path.starts_with('/') { - return Err(RpcError::Custom( - "DevnetProxyLayer path must start with `/`".to_string(), - )); + return Err(RpcError::Custom("DevnetProxyLayer path must start with `/`".to_string())); } Ok(Self { path, method: method.into() }) @@ -102,7 +101,7 @@ where fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.inner.poll_ready(cx).map_err(Into::into) } - + fn call(&mut self, mut req: Request) -> Self::Future { let modify = self.path.as_ref() == req.uri() && req.method() == Method::GET; @@ -118,7 +117,12 @@ where req.headers_mut().insert(ACCEPT, HeaderValue::from_static("application/json")); // Adjust the body to reflect the method call. - let raw_value = JsonRawValue::from_string("{\"address\":\"0x6b86e40118f29ebe393a75469b4d926c7a44c2e2681b6d319520b7c1156d114\", \"age\":5, \"name\":\"somename\"}".to_string()).unwrap(); + let raw_value = JsonRawValue::from_string( + "{\"address\":\"0x6b86e40118f29ebe393a75469b4d926c7a44c2e2681b6d319520b7c1156d114\\ + ", \"age\":5, \"name\":\"somename\"}" + .to_string(), + ) + .unwrap(); let param = Some(raw_value.as_ref()); let body = Body::from( diff --git a/crates/katana/rpc/rpc/src/server.rs b/crates/katana/rpc/rpc/src/server.rs index c17aefe350..8c8fdadd0b 100644 --- a/crates/katana/rpc/rpc/src/server.rs +++ b/crates/katana/rpc/rpc/src/server.rs @@ -32,23 +32,16 @@ use std::sync::Arc; use std::task::{Context, Poll}; use std::time::Duration; -use crate::future::{ConnectionGuard, FutureDriver, ServerHandle, StopHandle}; -use crate::logger::{Logger, TransportProtocol}; -use crate::transport::{http, ws}; - use futures_util::future::{BoxFuture, FutureExt}; use futures_util::io::{BufReader, BufWriter}; - use hyper::body::HttpBody; use jsonrpsee_core::id_providers::RandomIntegerIdProvider; - use jsonrpsee_core::server::helpers::MethodResponse; use jsonrpsee_core::server::host_filtering::AllowHosts; use jsonrpsee_core::server::resource_limiting::Resources; use jsonrpsee_core::server::rpc_module::Methods; use jsonrpsee_core::traits::IdProvider; use jsonrpsee_core::{http_helpers, Error, TEN_MB_SIZE_BYTES}; - use soketto::handshake::http::is_upgrade_request; use tokio::net::{TcpListener, TcpStream, ToSocketAddrs}; use tokio::sync::{watch, OwnedSemaphorePermit}; @@ -57,6 +50,10 @@ use tower::layer::util::Identity; use tower::{Layer, Service}; use tracing::{instrument, Instrument}; +use crate::future::{ConnectionGuard, FutureDriver, ServerHandle, StopHandle}; +use crate::logger::{Logger, TransportProtocol}; +use crate::transport::{http, ws}; + /// Default maximum connections allowed. 
const MAX_CONNECTIONS: u32 = 100; @@ -105,7 +102,8 @@ where { /// Start responding to connections requests. /// - /// This will run on the tokio runtime until the server is stopped or the `ServerHandle` is dropped. + /// This will run on the tokio runtime until the server is stopped or the `ServerHandle` is + /// dropped. pub fn start(mut self, methods: impl Into) -> Result { let methods = methods.into().initialize_resources(&self.resources)?; let (stop_tx, stop_rx) = watch::channel(()); @@ -289,8 +287,9 @@ impl Builder { /// Register a new resource kind. Errors if `label` is already registered, or if the number of /// registered resources on this server instance would exceed 8. /// - /// See the module documentation for [`resurce_limiting`](../jsonrpsee_utils/server/resource_limiting/index.html#resource-limiting) - /// for details. + /// See the module documentation for + /// [`resurce_limiting`](../jsonrpsee_utils/server/resource_limiting/index.html# + /// resource-limiting) for details. pub fn register_resource( mut self, label: &'static str, @@ -304,9 +303,10 @@ impl Builder { /// Add a logger to the builder [`Logger`](../jsonrpsee_core/logger/trait.Logger.html). /// /// ``` - /// use std::{time::Instant, net::SocketAddr}; + /// use std::net::SocketAddr; + /// use std::time::Instant; /// - /// use jsonrpsee_server::logger::{Logger, HttpRequest, MethodKind, Params, TransportProtocol}; + /// use jsonrpsee_server::logger::{HttpRequest, Logger, MethodKind, Params, TransportProtocol}; /// use jsonrpsee_server::ServerBuilder; /// /// #[derive(Clone)] @@ -315,28 +315,70 @@ impl Builder { /// impl Logger for MyLogger { /// type Instant = Instant; /// - /// fn on_connect(&self, remote_addr: SocketAddr, request: &HttpRequest, transport: TransportProtocol) { - /// println!("[MyLogger::on_call] remote_addr: {:?}, headers: {:?}, transport: {}", remote_addr, request, transport); + /// fn on_connect( + /// &self, + /// remote_addr: SocketAddr, + /// request: &HttpRequest, + /// transport: TransportProtocol, + /// ) { + /// println!( + /// "[MyLogger::on_call] remote_addr: {:?}, headers: {:?}, transport: {}", + /// remote_addr, request, transport + /// ); /// } /// /// fn on_request(&self, transport: TransportProtocol) -> Self::Instant { - /// Instant::now() + /// Instant::now() /// } /// - /// fn on_call(&self, method_name: &str, params: Params, kind: MethodKind, transport: TransportProtocol) { - /// println!("[MyLogger::on_call] method: '{}' params: {:?}, kind: {:?}, transport: {}", method_name, params, kind, transport); + /// fn on_call( + /// &self, + /// method_name: &str, + /// params: Params, + /// kind: MethodKind, + /// transport: TransportProtocol, + /// ) { + /// println!( + /// "[MyLogger::on_call] method: '{}' params: {:?}, kind: {:?}, transport: {}", + /// method_name, params, kind, transport + /// ); /// } /// - /// fn on_result(&self, method_name: &str, success: bool, started_at: Self::Instant, transport: TransportProtocol) { - /// println!("[MyLogger::on_result] '{}', worked? {}, time elapsed {:?}, transport: {}", method_name, success, started_at.elapsed(), transport); + /// fn on_result( + /// &self, + /// method_name: &str, + /// success: bool, + /// started_at: Self::Instant, + /// transport: TransportProtocol, + /// ) { + /// println!( + /// "[MyLogger::on_result] '{}', worked? 
{}, time elapsed {:?}, transport: {}", + /// method_name, + /// success, + /// started_at.elapsed(), + /// transport + /// ); /// } /// - /// fn on_response(&self, result: &str, started_at: Self::Instant, transport: TransportProtocol) { - /// println!("[MyLogger::on_response] result: {}, time elapsed {:?}, transport: {}", result, started_at.elapsed(), transport); + /// fn on_response( + /// &self, + /// result: &str, + /// started_at: Self::Instant, + /// transport: TransportProtocol, + /// ) { + /// println!( + /// "[MyLogger::on_response] result: {}, time elapsed {:?}, transport: {}", + /// result, + /// started_at.elapsed(), + /// transport + /// ); /// } /// /// fn on_disconnect(&self, remote_addr: SocketAddr, transport: TransportProtocol) { - /// println!("[MyLogger::on_disconnect] remote_addr: {:?}, transport: {}", remote_addr, transport); + /// println!( + /// "[MyLogger::on_disconnect] remote_addr: {:?}, transport: {}", + /// remote_addr, transport + /// ); /// } /// } /// @@ -371,6 +413,7 @@ impl Builder { /// /// ```rust /// use std::time::Duration; + /// /// use jsonrpsee_server::ServerBuilder; /// /// // Set the ping interval to 10 seconds. @@ -392,15 +435,15 @@ impl Builder { /// # Examples /// /// ```rust - /// use jsonrpsee_server::{ServerBuilder, RandomStringIdProvider, IdProvider}; + /// use jsonrpsee_server::{IdProvider, RandomStringIdProvider, ServerBuilder}; /// /// // static dispatch /// let builder1 = ServerBuilder::default().set_id_provider(RandomStringIdProvider::new(16)); /// /// // or dynamic dispatch - /// let builder2 = ServerBuilder::default().set_id_provider(Box::new(RandomStringIdProvider::new(16))); + /// let builder2 = + /// ServerBuilder::default().set_id_provider(Box::new(RandomStringIdProvider::new(16))); /// ``` - /// pub fn set_id_provider(mut self, id_provider: I) -> Self { self.id_provider = Arc::new(id_provider); self @@ -412,16 +455,16 @@ impl Builder { self } - /// Configure a custom [`tower::ServiceBuilder`] middleware for composing layers to be applied to the RPC service. + /// Configure a custom [`tower::ServiceBuilder`] middleware for composing layers to be applied + /// to the RPC service. /// /// Default: No tower layers are applied to the RPC service. 
/// /// # Examples /// /// ```rust - /// - /// use std::time::Duration; /// use std::net::SocketAddr; + /// use std::time::Duration; /// /// #[tokio::main] /// async fn main() { @@ -469,17 +512,13 @@ impl Builder { /// ```rust /// #[tokio::main] /// async fn main() { - /// let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap(); - /// let occupied_addr = listener.local_addr().unwrap(); - /// let addrs: &[std::net::SocketAddr] = &[ - /// occupied_addr, - /// "127.0.0.1:0".parse().unwrap(), - /// ]; - /// assert!(jsonrpsee_server::ServerBuilder::default().build(occupied_addr).await.is_err()); - /// assert!(jsonrpsee_server::ServerBuilder::default().build(addrs).await.is_ok()); + /// let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap(); + /// let occupied_addr = listener.local_addr().unwrap(); + /// let addrs: &[std::net::SocketAddr] = &[occupied_addr, "127.0.0.1:0".parse().unwrap()]; + /// assert!(jsonrpsee_server::ServerBuilder::default().build(occupied_addr).await.is_err()); + /// assert!(jsonrpsee_server::ServerBuilder::default().build(addrs).await.is_ok()); /// } /// ``` - /// pub async fn build(self, addrs: impl ToSocketAddrs) -> Result, Error> { let listener = TcpListener::bind(addrs).await?; @@ -497,23 +536,24 @@ impl Builder { /// /// /// ```rust + /// use std::time::Duration; + /// /// use jsonrpsee_server::ServerBuilder; /// use socket2::{Domain, Socket, Type}; - /// use std::time::Duration; /// /// #[tokio::main] /// async fn main() { - /// let addr = "127.0.0.1:0".parse().unwrap(); - /// let domain = Domain::for_address(addr); - /// let socket = Socket::new(domain, Type::STREAM, None).unwrap(); - /// socket.set_nonblocking(true).unwrap(); + /// let addr = "127.0.0.1:0".parse().unwrap(); + /// let domain = Domain::for_address(addr); + /// let socket = Socket::new(domain, Type::STREAM, None).unwrap(); + /// socket.set_nonblocking(true).unwrap(); /// - /// let address = addr.into(); - /// socket.bind(&address).unwrap(); + /// let address = addr.into(); + /// socket.bind(&address).unwrap(); /// - /// socket.listen(4096).unwrap(); + /// socket.listen(4096).unwrap(); /// - /// let server = ServerBuilder::new().build_from_tcp(socket).unwrap(); + /// let server = ServerBuilder::new().build_from_tcp(socket).unwrap(); /// } /// ``` pub fn build_from_tcp( @@ -608,7 +648,8 @@ impl hyper::service::Service> for TowerSe type Response = hyper::Response; // The following associated type is required by the `impl Server` bounds. - // It satisfies the server's bounds when the `tower::ServiceBuilder` is not set (ie `B: Identity`). + // It satisfies the server's bounds when the `tower::ServiceBuilder` is not set (ie `B: + // Identity`). type Error = Box; type Future = Pin> + Send>>; @@ -704,7 +745,8 @@ impl hyper::service::Service> for TowerSe } } -/// This is a glorified select listening for new messages, while also checking the `stop_receiver` signal. +/// This is a glorified select listening for new messages, while also checking the `stop_receiver` +/// signal. 
struct Monitored<'a, F> { future: F, stop_monitor: &'a StopHandle, diff --git a/crates/katana/rpc/rpc/src/transport/http.rs b/crates/katana/rpc/rpc/src/transport/http.rs index ccacb8a463..a081222486 100644 --- a/crates/katana/rpc/rpc/src/transport/http.rs +++ b/crates/katana/rpc/rpc/src/transport/http.rs @@ -2,8 +2,6 @@ use std::convert::Infallible; use std::net::SocketAddr; use std::sync::Arc; -use crate::logger::{self, Logger, TransportProtocol}; - use futures_util::future::Either; use futures_util::stream::{FuturesOrdered, StreamExt}; use http::Method; @@ -12,8 +10,8 @@ use jsonrpsee_core::http_helpers::read_body; use jsonrpsee_core::server::helpers::{ prepare_error, BatchResponse, BatchResponseBuilder, MethodResponse, }; -use jsonrpsee_core::server::rpc_module::MethodKind; -use jsonrpsee_core::server::{resource_limiting::Resources, rpc_module::Methods}; +use jsonrpsee_core::server::resource_limiting::Resources; +use jsonrpsee_core::server::rpc_module::{MethodKind, Methods}; use jsonrpsee_core::tracing::{rx_log_from_json, tx_log_from_str}; use jsonrpsee_core::JsonRawValue; use jsonrpsee_types::error::{ErrorCode, BATCHES_NOT_SUPPORTED_CODE, BATCHES_NOT_SUPPORTED_MSG}; @@ -21,6 +19,8 @@ use jsonrpsee_types::{ErrorObject, Id, InvalidRequest, Notification, Params, Req use tokio::sync::OwnedSemaphorePermit; use tracing::instrument; +use crate::logger::{self, Logger, TransportProtocol}; + type Notif<'a> = Notification<'a, Option<&'a JsonRawValue>>; /// Checks that content type of received request is valid for JSON-RPC. @@ -400,8 +400,7 @@ pub(crate) async fn handle_request( } pub(crate) mod response { - use jsonrpsee_types::error::reject_too_big_request; - use jsonrpsee_types::error::{ErrorCode, ErrorResponse}; + use jsonrpsee_types::error::{reject_too_big_request, ErrorCode, ErrorResponse}; use jsonrpsee_types::Id; const JSON: &str = "application/json; charset=utf-8"; diff --git a/crates/katana/rpc/rpc/src/transport/ws.rs b/crates/katana/rpc/rpc/src/transport/ws.rs index a6552a8d6c..600384906e 100644 --- a/crates/katana/rpc/rpc/src/transport/ws.rs +++ b/crates/katana/rpc/rpc/src/transport/ws.rs @@ -2,10 +2,6 @@ use std::pin::Pin; use std::task::{Context, Poll}; use std::time::Duration; -use crate::future::{FutureDriver, StopHandle}; -use crate::logger::{self, Logger, TransportProtocol}; -use crate::server::{MethodResult, ServiceData}; - use futures_channel::mpsc; use futures_util::future::{self, Either}; use futures_util::io::{BufReader, BufWriter}; @@ -32,6 +28,10 @@ use tokio_stream::wrappers::IntervalStream; use tokio_util::compat::Compat; use tracing::instrument; +use crate::future::{FutureDriver, StopHandle}; +use crate::logger::{self, Logger, TransportProtocol}; +use crate::server::{MethodResult, ServiceData}; + pub(crate) type Sender = soketto::Sender>>>; pub(crate) type Receiver = soketto::Receiver>>>; @@ -71,7 +71,8 @@ pub(crate) struct CallData<'a, L: Logger> { pub(crate) request_start: L::Instant, } -/// This is a glorified select listening for new messages, while also checking the `stop_receiver` signal. +/// This is a glorified select listening for new messages, while also checking the `stop_receiver` +/// signal. 
struct Monitored<'a, F> { future: F, stop_monitor: &'a StopHandle, @@ -150,11 +151,7 @@ pub(crate) async fn process_batch_request(b: Batch<'_, L>) -> Option< } } - if got_notif && batch_response.is_empty() { - None - } else { - Some(batch_response.finish()) - } + if got_notif && batch_response.is_empty() { None } else { Some(batch_response.finish()) } } else { Some(BatchResponse::error(Id::Null, ErrorObject::from(ErrorCode::ParseError))) } @@ -319,7 +316,8 @@ pub(crate) async fn execute_call<'a, L: Logger>( TransportProtocol::WebSocket, ); - // Don't adhere to any resource or subscription limits; always let unsubscribing happen! + // Don't adhere to any resource or subscription limits; always let unsubscribing + // happen! let result = callback(id, params, conn_id, max_response_body_size as usize); MethodResult::SendAndLogger(result) } @@ -381,7 +379,8 @@ pub(crate) async fn background_task( soketto::Incoming::Pong(_) => tracing::debug!("Received pong"), soketto::Incoming::Closed(_) => { // The closing reason is already logged by `soketto` trace log level. - // Return the `Closed` error to avoid logging unnecessary warnings on clean shutdown. + // Return the `Closed` error to avoid logging unnecessary warnings on + // clean shutdown. break Err(SokettoError::Closed); } } @@ -414,7 +413,8 @@ pub(crate) async fn background_task( continue; } - // These errors can not be gracefully handled, so just log them and terminate the connection. + // These errors can not be gracefully handled, so just log them and terminate + // the connection. MonitoredError::Selector(err) => { tracing::error!( "WS transport error: {}; terminate connection: {}", @@ -572,8 +572,9 @@ async fn send_task( let mut futs = future::select(next_ping, stopped); loop { - // Ensure select is cancel-safe by fetching and storing the `rx_item` that did not finish yet. - // Note: Although, this is cancel-safe already, avoid using `select!` macro for future proofing. + // Ensure select is cancel-safe by fetching and storing the `rx_item` that did not finish + // yet. Note: Although, this is cancel-safe already, avoid using `select!` macro for + // future proofing. match future::select(rx_item, futs).await { // Received message. 
Either::Left((Some(response), not_ready)) => { From 5a3865c7f87f7d79ed0451135de6332d3bd75263 Mon Sep 17 00:00:00 2001 From: Fabricio Robles Date: Tue, 29 Oct 2024 19:56:17 -0600 Subject: [PATCH 07/21] update cargo.lock --- Cargo.lock | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/Cargo.lock b/Cargo.lock index d85544f40c..55db8fc3fe 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8276,8 +8276,14 @@ dependencies = [ "dojo-utils", "dojo-world", "futures", + "futures-channel", + "futures-util", + "http 0.2.12", + "hyper 0.14.30", "indexmap 2.5.0", "jsonrpsee 0.16.3", + "jsonrpsee-core 0.16.3", + "jsonrpsee-types 0.16.3", "katana-cairo", "katana-core", "katana-executor", @@ -8295,10 +8301,15 @@ dependencies = [ "rstest 0.18.2", "serde", "serde_json", + "soketto 0.7.1", "starknet 0.12.0", + "starknet-crypto 0.7.2", "tempfile", "thiserror", "tokio", + "tokio-stream", + "tokio-util", + "tower 0.4.13", "tracing", "url", ] From 8e0886231075111bec2597e8efde5666041eac2c Mon Sep 17 00:00:00 2001 From: Fabricio Robles Date: Tue, 29 Oct 2024 20:11:53 -0600 Subject: [PATCH 08/21] Keep ProxyGetRequestLayer for / (health) endpoint --- crates/katana/node/src/lib.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/katana/node/src/lib.rs b/crates/katana/node/src/lib.rs index 44c22a9084..f0f9582df8 100644 --- a/crates/katana/node/src/lib.rs +++ b/crates/katana/node/src/lib.rs @@ -16,6 +16,7 @@ use config::{Config, SequencingConfig}; use dojo_metrics::exporters::prometheus::PrometheusRecorder; use dojo_metrics::{Report, Server as MetricsServer}; use hyper::{Method, Uri}; +use jsonrpsee::server::middleware::proxy_get_request::ProxyGetRequestLayer; use jsonrpsee::server::{AllowHosts, ServerBuilder, ServerHandle}; use jsonrpsee::RpcModule; use katana_core::backend::storage::Blockchain; @@ -33,7 +34,6 @@ use katana_pool::TxPool; use katana_primitives::env::{CfgEnv, FeeTokenAddressses}; use katana_rpc::dev::DevApi; use katana_rpc::metrics::RpcServerMetrics; -// use jsonrpsee::server::middleware::proxy_get_request::ProxyGetRequestLayer; use katana_rpc::proxy_get_request::DevnetProxyLayer; use katana_rpc::saya::SayaApi; use katana_rpc::starknet::forking::ForkedClient; @@ -308,7 +308,7 @@ pub async fn spawn( let middleware = tower::ServiceBuilder::new() .option_layer(cors) - .layer(DevnetProxyLayer::new("/", "health")?) + .layer(ProxyGetRequestLayer::new("/", "health")?) .layer(DevnetProxyLayer::new("/account_balance", "dev_accountBalance")?) .layer(DevnetProxyLayer::new("/fee_token", "dev_feeToken")?) .timeout(Duration::from_secs(20)); From e4666bc83a7fb61aa07dcd8164fa4616cd860644 Mon Sep 17 00:00:00 2001 From: Fabricio Robles Date: Wed, 30 Oct 2024 21:36:49 -0600 Subject: [PATCH 09/21] Enable query param contract_address for endpoint account_balance --- .../katana/rpc/rpc/src/proxy_get_request.rs | 23 +++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/crates/katana/rpc/rpc/src/proxy_get_request.rs b/crates/katana/rpc/rpc/src/proxy_get_request.rs index df3826fed8..b702ef6c07 100644 --- a/crates/katana/rpc/rpc/src/proxy_get_request.rs +++ b/crates/katana/rpc/rpc/src/proxy_get_request.rs @@ -1,6 +1,7 @@ //! Middleware that proxies requests at a specified URI to internal //! RPC method calls. 
+use std::collections::HashMap; use std::error::Error; use std::future::Future; use std::pin::Pin; @@ -14,6 +15,7 @@ use jsonrpsee_core::error::Error as RpcError; use jsonrpsee_core::JsonRawValue; use jsonrpsee_types::{Id, RequestSer}; use tower::{Layer, Service}; +use url::form_urlencoded; use crate::transport::http; @@ -103,10 +105,25 @@ where } fn call(&mut self, mut req: Request) -> Self::Future { - let modify = self.path.as_ref() == req.uri() && req.method() == Method::GET; + // let modify = self.path.as_ref() == req.uri() && req.method() == Method::GET; + let modify = req.method() == Method::GET; // Proxy the request to the appropriate method call. if modify { + let mut raw_value = None; + + //If method is dev_accountBalance then get the contract_address query param and assign it to raw_value + if self.method.to_string() == "dev_accountBalance".to_string() { + if let Some(query) = req.uri().query() { + let params: HashMap<_, _> = + form_urlencoded::parse(query.as_bytes()).into_owned().collect(); + if let Some(address) = params.get("contract_address") { + let json_string = format!(r#"{{"address":"{}"}}"#, address); + raw_value = Some(JsonRawValue::from_string(json_string).unwrap()); + } + } + } + // RPC methods are accessed with `POST`. *req.method_mut() = Method::POST; // Precautionary remove the URI. @@ -117,9 +134,7 @@ where req.headers_mut().insert(ACCEPT, HeaderValue::from_static("application/json")); // Adjust the body to reflect the method call. - let raw_value = JsonRawValue::from_string("{\"address\":\"0x6677fe62ee39c7b07401f754138502bab7fac99d2d3c5d37df7d1c6fab10819\", \"age\":5, \"name\":\"somename\"}".to_string()).unwrap(); - let param = Some(raw_value.as_ref()); - + let param = raw_value.as_ref().map(|value| value.as_ref()); let body = Body::from( serde_json::to_string(&RequestSer::borrowed(&Id::Number(0), &self.method, param)) .expect("Valid request; qed"), From 3ec9a5e7a296d0715494c2445a05f575251c759c Mon Sep 17 00:00:00 2001 From: Fabricio Robles Date: Wed, 30 Oct 2024 21:55:10 -0600 Subject: [PATCH 10/21] cargo fmt --- crates/katana/rpc/rpc/src/dev.rs | 5 +++-- crates/katana/rpc/rpc/src/proxy_get_request.rs | 3 ++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/crates/katana/rpc/rpc/src/dev.rs b/crates/katana/rpc/rpc/src/dev.rs index 5eaa5f639e..6eb17474bf 100644 --- a/crates/katana/rpc/rpc/src/dev.rs +++ b/crates/katana/rpc/rpc/src/dev.rs @@ -5,7 +5,9 @@ use jsonrpsee::core::{async_trait, Error}; use katana_core::backend::Backend; use katana_core::service::block_producer::{BlockProducer, BlockProducerMode, PendingExecutor}; use katana_executor::ExecutorFactory; -use katana_primitives::genesis::constant::{get_fee_token_balance_base_storage_address, ERC20_NAME_STORAGE_SLOT}; +use katana_primitives::genesis::constant::{ + get_fee_token_balance_base_storage_address, ERC20_NAME_STORAGE_SLOT, +}; use katana_primitives::ContractAddress; use katana_provider::traits::state::StateFactoryProvider; use katana_rpc_api::dev::DevApiServer; @@ -126,6 +128,5 @@ impl DevApiServer for DevApi { async fn predeployed_accounts(&self) -> Result, Error> { Ok(self.backend.chain_spec.genesis.accounts().map(|e| Account::new(*e.0, e.1)).collect()) - } } diff --git a/crates/katana/rpc/rpc/src/proxy_get_request.rs b/crates/katana/rpc/rpc/src/proxy_get_request.rs index b702ef6c07..210e7526cb 100644 --- a/crates/katana/rpc/rpc/src/proxy_get_request.rs +++ b/crates/katana/rpc/rpc/src/proxy_get_request.rs @@ -112,7 +112,8 @@ where if modify { let mut raw_value = None; - //If method is 
dev_accountBalance then get the contract_address query param and assign it to raw_value + // If method is dev_accountBalance then get the contract_address query param and assign + // it to raw_value if self.method.to_string() == "dev_accountBalance".to_string() { if let Some(query) = req.uri().query() { let params: HashMap<_, _> = From 378ae069ac3651672cd41d12f4423dc83c0e0e86 Mon Sep 17 00:00:00 2001 From: Fabricio Robles Date: Sat, 7 Dec 2024 10:06:23 -0600 Subject: [PATCH 11/21] Remove unused files --- crates/katana/rpc/rpc/src/lib.rs | 1 - crates/katana/rpc/rpc/src/server.rs | 928 --------------------- crates/katana/rpc/rpc/src/transport/mod.rs | 1 - crates/katana/rpc/rpc/src/transport/ws.rs | 614 -------------- 4 files changed, 1544 deletions(-) delete mode 100644 crates/katana/rpc/rpc/src/server.rs delete mode 100644 crates/katana/rpc/rpc/src/transport/ws.rs diff --git a/crates/katana/rpc/rpc/src/lib.rs b/crates/katana/rpc/rpc/src/lib.rs index 32e2e445bd..14c482e6bb 100644 --- a/crates/katana/rpc/rpc/src/lib.rs +++ b/crates/katana/rpc/rpc/src/lib.rs @@ -12,6 +12,5 @@ pub mod torii; mod future; mod logger; -mod server; mod transport; mod utils; diff --git a/crates/katana/rpc/rpc/src/server.rs b/crates/katana/rpc/rpc/src/server.rs deleted file mode 100644 index 8c8fdadd0b..0000000000 --- a/crates/katana/rpc/rpc/src/server.rs +++ /dev/null @@ -1,928 +0,0 @@ -// Copyright 2019-2021 Parity Technologies (UK) Ltd. -// -// Permission is hereby granted, free of charge, to any -// person obtaining a copy of this software and associated -// documentation files (the "Software"), to deal in the -// Software without restriction, including without -// limitation the rights to use, copy, modify, merge, -// publish, distribute, sublicense, and/or sell copies of -// the Software, and to permit persons to whom the Software -// is furnished to do so, subject to the following -// conditions: -// -// The above copyright notice and this permission notice -// shall be included in all copies or substantial portions -// of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF -// ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED -// TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A -// PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT -// SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY -// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION -// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR -// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -// DEALINGS IN THE SOFTWARE. 
- -use std::error::Error as StdError; -use std::future::Future; -use std::net::{SocketAddr, TcpListener as StdTcpListener}; -use std::pin::Pin; -use std::sync::Arc; -use std::task::{Context, Poll}; -use std::time::Duration; - -use futures_util::future::{BoxFuture, FutureExt}; -use futures_util::io::{BufReader, BufWriter}; -use hyper::body::HttpBody; -use jsonrpsee_core::id_providers::RandomIntegerIdProvider; -use jsonrpsee_core::server::helpers::MethodResponse; -use jsonrpsee_core::server::host_filtering::AllowHosts; -use jsonrpsee_core::server::resource_limiting::Resources; -use jsonrpsee_core::server::rpc_module::Methods; -use jsonrpsee_core::traits::IdProvider; -use jsonrpsee_core::{http_helpers, Error, TEN_MB_SIZE_BYTES}; -use soketto::handshake::http::is_upgrade_request; -use tokio::net::{TcpListener, TcpStream, ToSocketAddrs}; -use tokio::sync::{watch, OwnedSemaphorePermit}; -use tokio_util::compat::TokioAsyncReadCompatExt; -use tower::layer::util::Identity; -use tower::{Layer, Service}; -use tracing::{instrument, Instrument}; - -use crate::future::{ConnectionGuard, FutureDriver, ServerHandle, StopHandle}; -use crate::logger::{Logger, TransportProtocol}; -use crate::transport::{http, ws}; - -/// Default maximum connections allowed. -const MAX_CONNECTIONS: u32 = 100; - -/// JSON RPC server. -pub struct Server { - listener: TcpListener, - cfg: Settings, - resources: Resources, - logger: L, - id_provider: Arc, - service_builder: tower::ServiceBuilder, -} - -impl std::fmt::Debug for Server { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("Server") - .field("listener", &self.listener) - .field("cfg", &self.cfg) - .field("id_provider", &self.id_provider) - .field("resources", &self.resources) - .finish() - } -} - -impl Server { - /// Returns socket address to which the server is bound. - pub fn local_addr(&self) -> Result { - self.listener.local_addr().map_err(Into::into) - } -} - -impl Server -where - L: Logger, - B: Layer> + Send + 'static, - >>::Service: Send - + Service< - hyper::Request, - Response = hyper::Response, - Error = Box<(dyn StdError + Send + Sync + 'static)>, - >, - <>>::Service as Service>>::Future: Send, - U: HttpBody + Send + 'static, - ::Error: Send + Sync + StdError, - ::Data: Send, -{ - /// Start responding to connections requests. - /// - /// This will run on the tokio runtime until the server is stopped or the `ServerHandle` is - /// dropped. 
- pub fn start(mut self, methods: impl Into) -> Result { - let methods = methods.into().initialize_resources(&self.resources)?; - let (stop_tx, stop_rx) = watch::channel(()); - - let stop_handle = StopHandle::new(stop_rx); - - match self.cfg.tokio_runtime.take() { - Some(rt) => rt.spawn(self.start_inner(methods, stop_handle)), - None => tokio::spawn(self.start_inner(methods, stop_handle)), - }; - - Ok(ServerHandle::new(stop_tx)) - } - - async fn start_inner(self, methods: Methods, stop_handle: StopHandle) { - let max_request_body_size = self.cfg.max_request_body_size; - let max_response_body_size = self.cfg.max_response_body_size; - let max_log_length = self.cfg.max_log_length; - let allow_hosts = self.cfg.allow_hosts; - let resources = self.resources; - let logger = self.logger; - let batch_requests_supported = self.cfg.batch_requests_supported; - let id_provider = self.id_provider; - let max_subscriptions_per_connection = self.cfg.max_subscriptions_per_connection; - - let mut id: u32 = 0; - let connection_guard = ConnectionGuard::new(self.cfg.max_connections as usize); - let mut connections = FutureDriver::default(); - let mut incoming = Monitored::new(Incoming(self.listener), &stop_handle); - - loop { - match connections.select_with(&mut incoming).await { - Ok((socket, remote_addr)) => { - let data = ProcessConnection { - remote_addr, - methods: methods.clone(), - allow_hosts: allow_hosts.clone(), - resources: resources.clone(), - max_request_body_size, - max_response_body_size, - max_log_length, - batch_requests_supported, - id_provider: id_provider.clone(), - ping_interval: self.cfg.ping_interval, - stop_handle: stop_handle.clone(), - max_subscriptions_per_connection, - conn_id: id, - logger: logger.clone(), - max_connections: self.cfg.max_connections, - enable_http: self.cfg.enable_http, - enable_ws: self.cfg.enable_ws, - }; - process_connection( - &self.service_builder, - &connection_guard, - data, - socket, - &mut connections, - ); - id = id.wrapping_add(1); - } - Err(MonitoredError::Selector(err)) => { - tracing::error!("Error while awaiting a new connection: {:?}", err); - } - Err(MonitoredError::Shutdown) => break, - } - } - - connections.await; - } -} - -/// JSON-RPC Websocket server settings. -#[derive(Debug, Clone)] -struct Settings { - /// Maximum size in bytes of a request. - max_request_body_size: u32, - /// Maximum size in bytes of a response. - max_response_body_size: u32, - /// Maximum number of incoming connections allowed. - max_connections: u32, - /// Maximum number of subscriptions per connection. - max_subscriptions_per_connection: u32, - /// Max length for logging for requests and responses - /// - /// Logs bigger than this limit will be truncated. - max_log_length: u32, - /// Host filtering. - allow_hosts: AllowHosts, - /// Whether batch requests are supported by this server or not. - batch_requests_supported: bool, - /// Custom tokio runtime to run the server on. - tokio_runtime: Option, - /// The interval at which `Ping` frames are submitted. - ping_interval: Duration, - /// Enable HTTP. - enable_http: bool, - /// Enable WS. 
- enable_ws: bool, -} - -impl Default for Settings { - fn default() -> Self { - Self { - max_request_body_size: TEN_MB_SIZE_BYTES, - max_response_body_size: TEN_MB_SIZE_BYTES, - max_log_length: 4096, - max_subscriptions_per_connection: 1024, - max_connections: MAX_CONNECTIONS, - batch_requests_supported: true, - allow_hosts: AllowHosts::Any, - tokio_runtime: None, - ping_interval: Duration::from_secs(60), - enable_http: true, - enable_ws: true, - } - } -} - -/// Builder to configure and create a JSON-RPC server -#[derive(Debug)] -pub struct Builder { - settings: Settings, - resources: Resources, - logger: L, - id_provider: Arc, - service_builder: tower::ServiceBuilder, -} - -impl Default for Builder { - fn default() -> Self { - Builder { - settings: Settings::default(), - resources: Resources::default(), - logger: (), - id_provider: Arc::new(RandomIntegerIdProvider), - service_builder: tower::ServiceBuilder::new(), - } - } -} - -impl Builder { - /// Create a default server builder. - pub fn new() -> Self { - Self::default() - } -} - -impl Builder { - /// Set the maximum size of a request body in bytes. Default is 10 MiB. - pub fn max_request_body_size(mut self, size: u32) -> Self { - self.settings.max_request_body_size = size; - self - } - - /// Set the maximum size of a response body in bytes. Default is 10 MiB. - pub fn max_response_body_size(mut self, size: u32) -> Self { - self.settings.max_response_body_size = size; - self - } - - /// Set the maximum number of connections allowed. Default is 100. - pub fn max_connections(mut self, max: u32) -> Self { - self.settings.max_connections = max; - self - } - - /// Enables or disables support of [batch requests](https://www.jsonrpc.org/specification#batch). - /// By default, support is enabled. - pub fn batch_requests_supported(mut self, supported: bool) -> Self { - self.settings.batch_requests_supported = supported; - self - } - - /// Set the maximum number of connections allowed. Default is 1024. - pub fn max_subscriptions_per_connection(mut self, max: u32) -> Self { - self.settings.max_subscriptions_per_connection = max; - self - } - - /// Register a new resource kind. Errors if `label` is already registered, or if the number of - /// registered resources on this server instance would exceed 8. - /// - /// See the module documentation for - /// [`resurce_limiting`](../jsonrpsee_utils/server/resource_limiting/index.html# - /// resource-limiting) for details. - pub fn register_resource( - mut self, - label: &'static str, - capacity: u16, - default: u16, - ) -> Result { - self.resources.register(label, capacity, default)?; - Ok(self) - } - - /// Add a logger to the builder [`Logger`](../jsonrpsee_core/logger/trait.Logger.html). 
- /// - /// ``` - /// use std::net::SocketAddr; - /// use std::time::Instant; - /// - /// use jsonrpsee_server::logger::{HttpRequest, Logger, MethodKind, Params, TransportProtocol}; - /// use jsonrpsee_server::ServerBuilder; - /// - /// #[derive(Clone)] - /// struct MyLogger; - /// - /// impl Logger for MyLogger { - /// type Instant = Instant; - /// - /// fn on_connect( - /// &self, - /// remote_addr: SocketAddr, - /// request: &HttpRequest, - /// transport: TransportProtocol, - /// ) { - /// println!( - /// "[MyLogger::on_call] remote_addr: {:?}, headers: {:?}, transport: {}", - /// remote_addr, request, transport - /// ); - /// } - /// - /// fn on_request(&self, transport: TransportProtocol) -> Self::Instant { - /// Instant::now() - /// } - /// - /// fn on_call( - /// &self, - /// method_name: &str, - /// params: Params, - /// kind: MethodKind, - /// transport: TransportProtocol, - /// ) { - /// println!( - /// "[MyLogger::on_call] method: '{}' params: {:?}, kind: {:?}, transport: {}", - /// method_name, params, kind, transport - /// ); - /// } - /// - /// fn on_result( - /// &self, - /// method_name: &str, - /// success: bool, - /// started_at: Self::Instant, - /// transport: TransportProtocol, - /// ) { - /// println!( - /// "[MyLogger::on_result] '{}', worked? {}, time elapsed {:?}, transport: {}", - /// method_name, - /// success, - /// started_at.elapsed(), - /// transport - /// ); - /// } - /// - /// fn on_response( - /// &self, - /// result: &str, - /// started_at: Self::Instant, - /// transport: TransportProtocol, - /// ) { - /// println!( - /// "[MyLogger::on_response] result: {}, time elapsed {:?}, transport: {}", - /// result, - /// started_at.elapsed(), - /// transport - /// ); - /// } - /// - /// fn on_disconnect(&self, remote_addr: SocketAddr, transport: TransportProtocol) { - /// println!( - /// "[MyLogger::on_disconnect] remote_addr: {:?}, transport: {}", - /// remote_addr, transport - /// ); - /// } - /// } - /// - /// let builder = ServerBuilder::new().set_logger(MyLogger); - /// ``` - pub fn set_logger(self, logger: T) -> Builder { - Builder { - settings: self.settings, - resources: self.resources, - logger, - id_provider: self.id_provider, - service_builder: self.service_builder, - } - } - - /// Configure a custom [`tokio::runtime::Handle`] to run the server on. - /// - /// Default: [`tokio::spawn`] - pub fn custom_tokio_runtime(mut self, rt: tokio::runtime::Handle) -> Self { - self.settings.tokio_runtime = Some(rt); - self - } - - /// Configure the interval at which pings are submitted. - /// - /// This option is used to keep the connection alive, and is just submitting `Ping` frames, - /// without making any assumptions about when a `Pong` frame should be received. - /// - /// Default: 60 seconds. - /// - /// # Examples - /// - /// ```rust - /// use std::time::Duration; - /// - /// use jsonrpsee_server::ServerBuilder; - /// - /// // Set the ping interval to 10 seconds. - /// let builder = ServerBuilder::default().ping_interval(Duration::from_secs(10)); - /// ``` - pub fn ping_interval(mut self, interval: Duration) -> Self { - self.settings.ping_interval = interval; - self - } - - /// Configure custom `subscription ID` provider for the server to use - /// to when getting new subscription calls. - /// - /// You may choose static dispatch or dynamic dispatch because - /// `IdProvider` is implemented for `Box`. - /// - /// Default: [`RandomIntegerIdProvider`]. 
- /// - /// # Examples - /// - /// ```rust - /// use jsonrpsee_server::{IdProvider, RandomStringIdProvider, ServerBuilder}; - /// - /// // static dispatch - /// let builder1 = ServerBuilder::default().set_id_provider(RandomStringIdProvider::new(16)); - /// - /// // or dynamic dispatch - /// let builder2 = - /// ServerBuilder::default().set_id_provider(Box::new(RandomStringIdProvider::new(16))); - /// ``` - pub fn set_id_provider(mut self, id_provider: I) -> Self { - self.id_provider = Arc::new(id_provider); - self - } - - /// Sets host filtering. - pub fn set_host_filtering(mut self, allow: AllowHosts) -> Self { - self.settings.allow_hosts = allow; - self - } - - /// Configure a custom [`tower::ServiceBuilder`] middleware for composing layers to be applied - /// to the RPC service. - /// - /// Default: No tower layers are applied to the RPC service. - /// - /// # Examples - /// - /// ```rust - /// use std::net::SocketAddr; - /// use std::time::Duration; - /// - /// #[tokio::main] - /// async fn main() { - /// let builder = tower::ServiceBuilder::new().timeout(Duration::from_secs(2)); - /// - /// let server = jsonrpsee_server::ServerBuilder::new() - /// .set_middleware(builder) - /// .build("127.0.0.1:0".parse::().unwrap()) - /// .await - /// .unwrap(); - /// } - /// ``` - pub fn set_middleware(self, service_builder: tower::ServiceBuilder) -> Builder { - Builder { - settings: self.settings, - resources: self.resources, - logger: self.logger, - id_provider: self.id_provider, - service_builder, - } - } - - /// Configure the server to only serve JSON-RPC HTTP requests. - /// - /// Default: both http and ws are enabled. - pub fn http_only(mut self) -> Self { - self.settings.enable_http = true; - self.settings.enable_ws = false; - self - } - - /// Configure the server to only serve JSON-RPC WebSocket requests. - /// - /// That implies that server just denies HTTP requests which isn't a WebSocket upgrade request - /// - /// Default: both http and ws are enabled. - pub fn ws_only(mut self) -> Self { - self.settings.enable_http = false; - self.settings.enable_ws = true; - self - } - - /// Finalize the configuration of the server. Consumes the [`Builder`]. - /// - /// ```rust - /// #[tokio::main] - /// async fn main() { - /// let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap(); - /// let occupied_addr = listener.local_addr().unwrap(); - /// let addrs: &[std::net::SocketAddr] = &[occupied_addr, "127.0.0.1:0".parse().unwrap()]; - /// assert!(jsonrpsee_server::ServerBuilder::default().build(occupied_addr).await.is_err()); - /// assert!(jsonrpsee_server::ServerBuilder::default().build(addrs).await.is_ok()); - /// } - /// ``` - pub async fn build(self, addrs: impl ToSocketAddrs) -> Result, Error> { - let listener = TcpListener::bind(addrs).await?; - - Ok(Server { - listener, - cfg: self.settings, - resources: self.resources, - logger: self.logger, - id_provider: self.id_provider, - service_builder: self.service_builder, - }) - } - - /// Finalizes the configuration of the server with customized TCP settings on the socket. 
- /// - /// - /// ```rust - /// use std::time::Duration; - /// - /// use jsonrpsee_server::ServerBuilder; - /// use socket2::{Domain, Socket, Type}; - /// - /// #[tokio::main] - /// async fn main() { - /// let addr = "127.0.0.1:0".parse().unwrap(); - /// let domain = Domain::for_address(addr); - /// let socket = Socket::new(domain, Type::STREAM, None).unwrap(); - /// socket.set_nonblocking(true).unwrap(); - /// - /// let address = addr.into(); - /// socket.bind(&address).unwrap(); - /// - /// socket.listen(4096).unwrap(); - /// - /// let server = ServerBuilder::new().build_from_tcp(socket).unwrap(); - /// } - /// ``` - pub fn build_from_tcp( - self, - listener: impl Into, - ) -> Result, Error> { - let listener = TcpListener::from_std(listener.into())?; - - Ok(Server { - listener, - cfg: self.settings, - resources: self.resources, - logger: self.logger, - id_provider: self.id_provider, - service_builder: self.service_builder, - }) - } -} - -pub(crate) enum MethodResult { - JustLogger(MethodResponse), - SendAndLogger(MethodResponse), -} - -impl MethodResult { - pub(crate) fn as_inner(&self) -> &MethodResponse { - match &self { - Self::JustLogger(r) => r, - Self::SendAndLogger(r) => r, - } - } - - pub(crate) fn into_inner(self) -> MethodResponse { - match self { - Self::JustLogger(r) => r, - Self::SendAndLogger(r) => r, - } - } -} - -/// Data required by the server to handle requests. -#[derive(Debug, Clone)] -pub(crate) struct ServiceData { - /// Remote server address. - pub(crate) remote_addr: SocketAddr, - /// Registered server methods. - pub(crate) methods: Methods, - /// Access control. - pub(crate) allow_hosts: AllowHosts, - /// Tracker for currently used resources on the server. - pub(crate) resources: Resources, - /// Max request body size. - pub(crate) max_request_body_size: u32, - /// Max response body size. - pub(crate) max_response_body_size: u32, - /// Max length for logging for request and response - /// - /// Logs bigger than this limit will be truncated. - pub(crate) max_log_length: u32, - /// Whether batch requests are supported by this server or not. - pub(crate) batch_requests_supported: bool, - /// Subscription ID provider. - pub(crate) id_provider: Arc, - /// Ping interval - pub(crate) ping_interval: Duration, - /// Stop handle. - pub(crate) stop_handle: StopHandle, - /// Max subscriptions per connection. - pub(crate) max_subscriptions_per_connection: u32, - /// Connection ID - pub(crate) conn_id: u32, - /// Logger. - pub(crate) logger: L, - /// Handle to hold a `connection permit`. - pub(crate) conn: Arc, - /// Enable HTTP. - pub(crate) enable_http: bool, - /// Enable WS. - pub(crate) enable_ws: bool, -} - -/// JsonRPSee service compatible with `tower`. -/// -/// # Note -/// This is similar to [`hyper::service::service_fn`]. -#[derive(Debug, Clone)] -pub struct TowerService { - inner: ServiceData, -} - -impl hyper::service::Service> for TowerService { - type Response = hyper::Response; - - // The following associated type is required by the `impl Server` bounds. - // It satisfies the server's bounds when the `tower::ServiceBuilder` is not set (ie `B: - // Identity`). - type Error = Box; - - type Future = Pin> + Send>>; - - /// Opens door for back pressure implementation. 
- fn poll_ready(&mut self, _: &mut Context) -> Poll> { - Poll::Ready(Ok(())) - } - - fn call(&mut self, request: hyper::Request) -> Self::Future { - tracing::trace!("{:?}", request); - - let host = match http_helpers::read_header_value(request.headers(), hyper::header::HOST) { - Some(host) => host, - None if request.version() == hyper::Version::HTTP_2 => match request.uri().host() { - Some(host) => host, - None => return async move { Ok(http::response::malformed()) }.boxed(), - }, - None => return async move { Ok(http::response::malformed()) }.boxed(), - }; - - if let Err(e) = self.inner.allow_hosts.verify(host) { - tracing::warn!("Denied request: {}", e); - return async { Ok(http::response::host_not_allowed()) }.boxed(); - } - - let is_upgrade_request = is_upgrade_request(&request); - - if self.inner.enable_ws && is_upgrade_request { - let mut server = soketto::handshake::http::Server::new(); - - let response = match server.receive_request(&request) { - Ok(response) => { - self.inner.logger.on_connect( - self.inner.remote_addr, - &request, - TransportProtocol::WebSocket, - ); - let data = self.inner.clone(); - - tokio::spawn( - async move { - let upgraded = match hyper::upgrade::on(request).await { - Ok(u) => u, - Err(e) => { - tracing::warn!("Could not upgrade connection: {}", e); - return; - } - }; - - let stream = BufReader::new(BufWriter::new(upgraded.compat())); - let mut ws_builder = server.into_builder(stream); - ws_builder.set_max_message_size(data.max_request_body_size as usize); - let (sender, receiver) = ws_builder.finish(); - - let _ = ws::background_task::(sender, receiver, data).await; - } - .in_current_span(), - ); - - response.map(|()| hyper::Body::empty()) - } - Err(e) => { - tracing::error!("Could not upgrade connection: {}", e); - hyper::Response::new(hyper::Body::from(format!( - "Could not upgrade connection: {}", - e - ))) - } - }; - - async { Ok(response) }.boxed() - } else if self.inner.enable_http && !is_upgrade_request { - // The request wasn't an upgrade request; let's treat it as a standard HTTP request: - let data = http::HandleRequest { - methods: self.inner.methods.clone(), - resources: self.inner.resources.clone(), - max_request_body_size: self.inner.max_request_body_size, - max_response_body_size: self.inner.max_response_body_size, - max_log_length: self.inner.max_log_length, - batch_requests_supported: self.inner.batch_requests_supported, - logger: self.inner.logger.clone(), - conn: self.inner.conn.clone(), - remote_addr: self.inner.remote_addr, - }; - - self.inner.logger.on_connect(self.inner.remote_addr, &request, TransportProtocol::Http); - - Box::pin(http::handle_request(request, data).map(Ok)) - } else { - Box::pin(async { http::response::denied() }.map(Ok)) - } - } -} - -/// This is a glorified select listening for new messages, while also checking the `stop_receiver` -/// signal. 
-struct Monitored<'a, F> { - future: F, - stop_monitor: &'a StopHandle, -} - -impl<'a, F> Monitored<'a, F> { - fn new(future: F, stop_monitor: &'a StopHandle) -> Self { - Monitored { future, stop_monitor } - } -} - -enum MonitoredError { - Shutdown, - Selector(E), -} - -struct Incoming(TcpListener); - -impl<'a> Future for Monitored<'a, Incoming> { - type Output = Result<(TcpStream, SocketAddr), MonitoredError>; - - fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { - let this = Pin::into_inner(self); - - if this.stop_monitor.shutdown_requested() { - return Poll::Ready(Err(MonitoredError::Shutdown)); - } - - this.future.0.poll_accept(cx).map_err(MonitoredError::Selector) - } -} - -impl<'a, 'f, F, T, E> Future for Monitored<'a, Pin<&'f mut F>> -where - F: Future>, -{ - type Output = Result>; - - fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { - let this = Pin::into_inner(self); - - if this.stop_monitor.shutdown_requested() { - return Poll::Ready(Err(MonitoredError::Shutdown)); - } - - this.future.poll_unpin(cx).map_err(MonitoredError::Selector) - } -} - -struct ProcessConnection { - /// Remote server address. - remote_addr: SocketAddr, - /// Registered server methods. - methods: Methods, - /// Access control. - allow_hosts: AllowHosts, - /// Tracker for currently used resources on the server. - resources: Resources, - /// Max request body size. - max_request_body_size: u32, - /// Max response body size. - max_response_body_size: u32, - /// Max length for logging for request and response - /// - /// Logs bigger than this limit will be truncated. - max_log_length: u32, - /// Whether batch requests are supported by this server or not. - batch_requests_supported: bool, - /// Subscription ID provider. - id_provider: Arc, - /// Ping interval - ping_interval: Duration, - /// Stop handle. - stop_handle: StopHandle, - /// Max subscriptions per connection. - max_subscriptions_per_connection: u32, - /// Max connections, - max_connections: u32, - /// Connection ID - conn_id: u32, - /// Logger. - logger: L, - /// Allow JSON-RPC HTTP requests. - enable_http: bool, - /// Allow JSON-RPC WS request and WS upgrade requests. - enable_ws: bool, -} - -#[instrument(name = "connection", skip_all, fields(remote_addr = %cfg.remote_addr, conn_id = %cfg.conn_id), level = "INFO")] -fn process_connection<'a, L: Logger, B, U>( - service_builder: &tower::ServiceBuilder, - connection_guard: &ConnectionGuard, - cfg: ProcessConnection, - socket: TcpStream, - connections: &mut FutureDriver>, -) where - B: Layer> + Send + 'static, - >>::Service: Send - + Service< - hyper::Request, - Response = hyper::Response, - Error = Box<(dyn StdError + Send + Sync + 'static)>, - >, - <>>::Service as Service>>::Future: Send, - U: HttpBody + Send + 'static, - ::Error: Send + Sync + StdError, - ::Data: Send, -{ - if let Err(e) = socket.set_nodelay(true) { - tracing::warn!("Could not set NODELAY on socket: {:?}", e); - return; - } - - let conn = match connection_guard.try_acquire() { - Some(conn) => conn, - None => { - tracing::warn!("Too many connections. 
Please try again later."); - connections.add(http::reject_connection(socket).in_current_span().boxed()); - return; - } - }; - - let max_conns = cfg.max_connections as usize; - let curr_conns = max_conns - connection_guard.available_connections(); - tracing::info!("Accepting new connection {}/{}", curr_conns, max_conns); - - let tower_service = TowerService { - inner: ServiceData { - remote_addr: cfg.remote_addr, - methods: cfg.methods, - allow_hosts: cfg.allow_hosts, - resources: cfg.resources, - max_request_body_size: cfg.max_request_body_size, - max_response_body_size: cfg.max_response_body_size, - max_log_length: cfg.max_log_length, - batch_requests_supported: cfg.batch_requests_supported, - id_provider: cfg.id_provider, - ping_interval: cfg.ping_interval, - stop_handle: cfg.stop_handle.clone(), - max_subscriptions_per_connection: cfg.max_subscriptions_per_connection, - conn_id: cfg.conn_id, - logger: cfg.logger, - conn: Arc::new(conn), - enable_http: cfg.enable_http, - enable_ws: cfg.enable_ws, - }, - }; - - let service = service_builder.service(tower_service); - - connections - .add(Box::pin(try_accept_connection(socket, service, cfg.stop_handle).in_current_span())); -} - -// Attempts to create a HTTP connection from a socket. -async fn try_accept_connection(socket: TcpStream, service: S, mut stop_handle: StopHandle) -where - S: Service, Response = hyper::Response> + Send + 'static, - S::Error: Into>, - S::Future: Send, - Bd: HttpBody + Send + 'static, - ::Error: Send + Sync + StdError, - ::Data: Send, -{ - let conn = hyper::server::conn::Http::new().serve_connection(socket, service).with_upgrades(); - - tokio::pin!(conn); - - tokio::select! { - res = &mut conn => { - if let Err(e) = res { - tracing::warn!("HTTP serve connection failed {:?}", e); - } - } - _ = stop_handle.shutdown() => { - conn.graceful_shutdown(); - } - } -} diff --git a/crates/katana/rpc/rpc/src/transport/mod.rs b/crates/katana/rpc/rpc/src/transport/mod.rs index 8b293e269f..d064c8bd6f 100644 --- a/crates/katana/rpc/rpc/src/transport/mod.rs +++ b/crates/katana/rpc/rpc/src/transport/mod.rs @@ -1,2 +1 @@ pub(crate) mod http; -pub(crate) mod ws; diff --git a/crates/katana/rpc/rpc/src/transport/ws.rs b/crates/katana/rpc/rpc/src/transport/ws.rs deleted file mode 100644 index 600384906e..0000000000 --- a/crates/katana/rpc/rpc/src/transport/ws.rs +++ /dev/null @@ -1,614 +0,0 @@ -use std::pin::Pin; -use std::task::{Context, Poll}; -use std::time::Duration; - -use futures_channel::mpsc; -use futures_util::future::{self, Either}; -use futures_util::io::{BufReader, BufWriter}; -use futures_util::stream::FuturesOrdered; -use futures_util::{Future, FutureExt, StreamExt}; -use hyper::upgrade::Upgraded; -use jsonrpsee_core::server::helpers::{ - prepare_error, BatchResponse, BatchResponseBuilder, BoundedSubscriptions, MethodResponse, - MethodSink, -}; -use jsonrpsee_core::server::resource_limiting::Resources; -use jsonrpsee_core::server::rpc_module::{ConnState, MethodKind, Methods}; -use jsonrpsee_core::tracing::{rx_log_from_json, tx_log_from_str}; -use jsonrpsee_core::traits::IdProvider; -use jsonrpsee_core::{Error, JsonRawValue}; -use jsonrpsee_types::error::{ - reject_too_big_request, reject_too_many_subscriptions, ErrorCode, BATCHES_NOT_SUPPORTED_CODE, - BATCHES_NOT_SUPPORTED_MSG, -}; -use jsonrpsee_types::{ErrorObject, Id, InvalidRequest, Notification, Params, Request}; -use soketto::connection::Error as SokettoError; -use soketto::data::ByteSlice125; -use tokio_stream::wrappers::IntervalStream; -use tokio_util::compat::Compat; 
-use tracing::instrument; - -use crate::future::{FutureDriver, StopHandle}; -use crate::logger::{self, Logger, TransportProtocol}; -use crate::server::{MethodResult, ServiceData}; - -pub(crate) type Sender = soketto::Sender>>>; -pub(crate) type Receiver = soketto::Receiver>>>; - -pub(crate) async fn send_message(sender: &mut Sender, response: String) -> Result<(), Error> { - sender.send_text_owned(response).await?; - sender.flush().await.map_err(Into::into) -} - -pub(crate) async fn send_ping(sender: &mut Sender) -> Result<(), Error> { - tracing::debug!("Send ping"); - // Submit empty slice as "optional" parameter. - let slice: &[u8] = &[]; - // Byte slice fails if the provided slice is larger than 125 bytes. - let byte_slice = - ByteSlice125::try_from(slice).expect("Empty slice should fit into ByteSlice125"); - sender.send_ping(byte_slice).await?; - sender.flush().await.map_err(Into::into) -} - -#[derive(Debug, Clone)] -pub(crate) struct Batch<'a, L: Logger> { - pub(crate) data: Vec, - pub(crate) call: CallData<'a, L>, -} - -#[derive(Debug, Clone)] -pub(crate) struct CallData<'a, L: Logger> { - pub(crate) conn_id: usize, - pub(crate) bounded_subscriptions: BoundedSubscriptions, - pub(crate) id_provider: &'a dyn IdProvider, - pub(crate) methods: &'a Methods, - pub(crate) max_response_body_size: u32, - pub(crate) max_log_length: u32, - pub(crate) resources: &'a Resources, - pub(crate) sink: &'a MethodSink, - pub(crate) logger: &'a L, - pub(crate) request_start: L::Instant, -} - -/// This is a glorified select listening for new messages, while also checking the `stop_receiver` -/// signal. -struct Monitored<'a, F> { - future: F, - stop_monitor: &'a StopHandle, -} - -impl<'a, F> Monitored<'a, F> { - fn new(future: F, stop_monitor: &'a StopHandle) -> Self { - Monitored { future, stop_monitor } - } -} - -enum MonitoredError { - Shutdown, - Selector(E), -} - -impl<'a, 'f, F, T, E> Future for Monitored<'a, Pin<&'f mut F>> -where - F: Future>, -{ - type Output = Result>; - - fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { - let this = Pin::into_inner(self); - - if this.stop_monitor.shutdown_requested() { - return Poll::Ready(Err(MonitoredError::Shutdown)); - } - - this.future.poll_unpin(cx).map_err(MonitoredError::Selector) - } -} - -// Batch responses must be sent back as a single message so we read the results from each -// request in the batch and read the results off of a new channel, `rx_batch`, and then send the -// complete batch response back to the client over `tx`. -#[instrument(name = "batch", skip(b), level = "TRACE")] -pub(crate) async fn process_batch_request(b: Batch<'_, L>) -> Option { - let Batch { data, call } = b; - - if let Ok(batch) = serde_json::from_slice::>(&data) { - let mut got_notif = false; - let mut batch_response = - BatchResponseBuilder::new_with_limit(call.max_response_body_size as usize); - - let mut pending_calls: FuturesOrdered<_> = batch - .into_iter() - .filter_map(|v| { - if let Ok(req) = serde_json::from_str::(v.get()) { - Some(Either::Right(async { - execute_call(req, call.clone()).await.into_inner() - })) - } else if let Ok(_notif) = - serde_json::from_str::>(v.get()) - { - // notifications should not be answered. 
- got_notif = true; - None - } else { - // valid JSON but could be not parsable as `InvalidRequest` - let id = match serde_json::from_str::(v.get()) { - Ok(err) => err.id, - Err(_) => Id::Null, - }; - - Some(Either::Left(async { - MethodResponse::error(id, ErrorObject::from(ErrorCode::InvalidRequest)) - })) - } - }) - .collect(); - - while let Some(response) = pending_calls.next().await { - if let Err(too_large) = batch_response.append(&response) { - return Some(too_large); - } - } - - if got_notif && batch_response.is_empty() { None } else { Some(batch_response.finish()) } - } else { - Some(BatchResponse::error(Id::Null, ErrorObject::from(ErrorCode::ParseError))) - } -} - -pub(crate) async fn process_single_request( - data: Vec, - call: CallData<'_, L>, -) -> MethodResult { - if let Ok(req) = serde_json::from_slice::(&data) { - execute_call_with_tracing(req, call).await - } else { - let (id, code) = prepare_error(&data); - MethodResult::SendAndLogger(MethodResponse::error(id, ErrorObject::from(code))) - } -} - -#[instrument(name = "method_call", fields(method = req.method.as_ref()), skip(call, req), level = "TRACE")] -pub(crate) async fn execute_call_with_tracing<'a, L: Logger>( - req: Request<'a>, - call: CallData<'_, L>, -) -> MethodResult { - execute_call(req, call).await -} - -/// Execute a call which returns result of the call with a additional sink -/// to fire a signal once the subscription call has been answered. -/// -/// Returns `(MethodResponse, None)` on every call that isn't a subscription -/// Otherwise `(MethodResponse, Some(PendingSubscriptionCallTx)`. -pub(crate) async fn execute_call<'a, L: Logger>( - req: Request<'a>, - call: CallData<'_, L>, -) -> MethodResult { - let CallData { - resources, - methods, - max_response_body_size, - max_log_length, - conn_id, - bounded_subscriptions, - id_provider, - sink, - logger, - request_start, - } = call; - - rx_log_from_json(&req, call.max_log_length); - - let params = Params::new(req.params.map(|params| params.get())); - let name = &req.method; - let id = req.id; - - let response = match methods.method_with_name(name) { - None => { - logger.on_call( - name, - params.clone(), - logger::MethodKind::Unknown, - TransportProtocol::WebSocket, - ); - let response = MethodResponse::error(id, ErrorObject::from(ErrorCode::MethodNotFound)); - MethodResult::SendAndLogger(response) - } - Some((name, method)) => match &method.inner() { - MethodKind::Sync(callback) => { - logger.on_call( - name, - params.clone(), - logger::MethodKind::MethodCall, - TransportProtocol::WebSocket, - ); - match method.claim(name, resources) { - Ok(guard) => { - let r = (callback)(id, params, max_response_body_size as usize); - drop(guard); - MethodResult::SendAndLogger(r) - } - Err(err) => { - tracing::error!( - "[Methods::execute_with_resources] failed to lock resources: {}", - err - ); - let response = - MethodResponse::error(id, ErrorObject::from(ErrorCode::ServerIsBusy)); - MethodResult::SendAndLogger(response) - } - } - } - MethodKind::Async(callback) => { - logger.on_call( - name, - params.clone(), - logger::MethodKind::MethodCall, - TransportProtocol::WebSocket, - ); - match method.claim(name, resources) { - Ok(guard) => { - let id = id.into_owned(); - let params = params.into_owned(); - - let response = (callback)( - id, - params, - conn_id, - max_response_body_size as usize, - Some(guard), - ) - .await; - MethodResult::SendAndLogger(response) - } - Err(err) => { - tracing::error!( - "[Methods::execute_with_resources] failed to lock resources: {}", - err - 
); - let response = - MethodResponse::error(id, ErrorObject::from(ErrorCode::ServerIsBusy)); - MethodResult::SendAndLogger(response) - } - } - } - MethodKind::Subscription(callback) => { - logger.on_call( - name, - params.clone(), - logger::MethodKind::Subscription, - TransportProtocol::WebSocket, - ); - match method.claim(name, resources) { - Ok(guard) => { - if let Some(cn) = bounded_subscriptions.acquire() { - let conn_state = ConnState { conn_id, close_notify: cn, id_provider }; - let response = - callback(id.clone(), params, sink.clone(), conn_state, Some(guard)) - .await; - MethodResult::JustLogger(response) - } else { - let response = MethodResponse::error( - id, - reject_too_many_subscriptions(bounded_subscriptions.max()), - ); - MethodResult::SendAndLogger(response) - } - } - Err(err) => { - tracing::error!( - "[Methods::execute_with_resources] failed to lock resources: {}", - err - ); - let response = - MethodResponse::error(id, ErrorObject::from(ErrorCode::ServerIsBusy)); - MethodResult::SendAndLogger(response) - } - } - } - MethodKind::Unsubscription(callback) => { - logger.on_call( - name, - params.clone(), - logger::MethodKind::Unsubscription, - TransportProtocol::WebSocket, - ); - - // Don't adhere to any resource or subscription limits; always let unsubscribing - // happen! - let result = callback(id, params, conn_id, max_response_body_size as usize); - MethodResult::SendAndLogger(result) - } - }, - }; - - let r = response.as_inner(); - - tx_log_from_str(&r.result, max_log_length); - logger.on_result(name, r.success, request_start, TransportProtocol::WebSocket); - response -} - -pub(crate) async fn background_task( - sender: Sender, - mut receiver: Receiver, - svc: ServiceData, -) -> Result<(), Error> { - let ServiceData { - methods, - resources, - max_request_body_size, - max_response_body_size, - max_log_length, - batch_requests_supported, - stop_handle, - id_provider, - ping_interval, - max_subscriptions_per_connection, - conn_id, - logger, - remote_addr, - conn, - .. - } = svc; - - let (tx, rx) = mpsc::unbounded::(); - let bounded_subscriptions = BoundedSubscriptions::new(max_subscriptions_per_connection); - let sink = MethodSink::new_with_limit(tx, max_response_body_size, max_log_length); - - // Spawn another task that sends out the responses on the Websocket. - tokio::spawn(send_task(rx, sender, stop_handle.clone(), ping_interval)); - - // Buffer for incoming data. - let mut data = Vec::with_capacity(100); - let mut method_executors = FutureDriver::default(); - let logger = &logger; - - let result = loop { - data.clear(); - - { - // Need the extra scope to drop this pinned future and reclaim access to `data` - let receive = async { - // Identical loop to `soketto::receive_data` with debug logs for `Pong` frames. - loop { - match receiver.receive(&mut data).await? { - soketto::Incoming::Data(d) => break Ok(d), - soketto::Incoming::Pong(_) => tracing::debug!("Received pong"), - soketto::Incoming::Closed(_) => { - // The closing reason is already logged by `soketto` trace log level. - // Return the `Closed` error to avoid logging unnecessary warnings on - // clean shutdown. 
- break Err(SokettoError::Closed); - } - } - } - }; - - tokio::pin!(receive); - - if let Err(err) = - method_executors.select_with(Monitored::new(receive, &stop_handle)).await - { - match err { - MonitoredError::Selector(SokettoError::Closed) => { - tracing::debug!( - "WS transport: remote peer terminated the connection: {}", - conn_id - ); - break Ok(()); - } - MonitoredError::Selector(SokettoError::MessageTooLarge { - current, - maximum, - }) => { - tracing::warn!( - "WS transport error: request length: {} exceeded max limit: {} bytes", - current, - maximum - ); - sink.send_error(Id::Null, reject_too_big_request(max_request_body_size)); - continue; - } - - // These errors can not be gracefully handled, so just log them and terminate - // the connection. - MonitoredError::Selector(err) => { - tracing::error!( - "WS transport error: {}; terminate connection: {}", - err, - conn_id - ); - break Err(err.into()); - } - MonitoredError::Shutdown => { - break Ok(()); - } - }; - }; - }; - - let request_start = logger.on_request(TransportProtocol::WebSocket); - - let first_non_whitespace = data.iter().find(|byte| !byte.is_ascii_whitespace()); - match first_non_whitespace { - Some(b'{') => { - let data = std::mem::take(&mut data); - let sink = sink.clone(); - let resources = &resources; - let methods = &methods; - let bounded_subscriptions = bounded_subscriptions.clone(); - let id_provider = &*id_provider; - - let fut = async move { - let call = CallData { - conn_id: conn_id as usize, - resources, - max_response_body_size, - max_log_length, - methods, - bounded_subscriptions, - sink: &sink, - id_provider, - logger, - request_start, - }; - - match process_single_request(data, call).await { - MethodResult::JustLogger(r) => { - logger.on_response( - &r.result, - request_start, - TransportProtocol::WebSocket, - ); - } - MethodResult::SendAndLogger(r) => { - logger.on_response( - &r.result, - request_start, - TransportProtocol::WebSocket, - ); - let _ = sink.send_raw(r.result); - } - }; - } - .boxed(); - - method_executors.add(fut); - } - Some(b'[') if !batch_requests_supported => { - let response = MethodResponse::error( - Id::Null, - ErrorObject::borrowed( - BATCHES_NOT_SUPPORTED_CODE, - &BATCHES_NOT_SUPPORTED_MSG, - None, - ), - ); - logger.on_response(&response.result, request_start, TransportProtocol::WebSocket); - let _ = sink.send_raw(response.result); - } - Some(b'[') => { - // Make sure the following variables are not moved into async closure below. - let resources = &resources; - let methods = &methods; - let bounded_subscriptions = bounded_subscriptions.clone(); - let sink = sink.clone(); - let id_provider = id_provider.clone(); - let data = std::mem::take(&mut data); - - let fut = async move { - let response = process_batch_request(Batch { - data, - call: CallData { - conn_id: conn_id as usize, - resources, - max_response_body_size, - max_log_length, - methods, - bounded_subscriptions, - sink: &sink, - id_provider: &*id_provider, - logger, - request_start, - }, - }) - .await; - - if let Some(response) = response { - tx_log_from_str(&response.result, max_log_length); - logger.on_response( - &response.result, - request_start, - TransportProtocol::WebSocket, - ); - let _ = sink.send_raw(response.result); - } - }; - - method_executors.add(Box::pin(fut)); - } - _ => { - sink.send_error(Id::Null, ErrorCode::ParseError.into()); - } - } - }; - - logger.on_disconnect(remote_addr, TransportProtocol::WebSocket); - - // Drive all running methods to completion. 
- // **NOTE** Do not return early in this function. This `await` needs to run to guarantee - // proper drop behaviour. - method_executors.await; - - // Notify all listeners and close down associated tasks. - sink.close(); - bounded_subscriptions.close(); - - drop(conn); - - result -} - -/// A task that waits for new messages via the `rx channel` and sends them out on the `WebSocket`. -async fn send_task( - mut rx: mpsc::UnboundedReceiver, - mut ws_sender: Sender, - mut stop_handle: StopHandle, - ping_interval: Duration, -) { - // Received messages from the WebSocket. - let mut rx_item = rx.next(); - - // Interval to send out continuously `pings`. - let ping_interval = IntervalStream::new(tokio::time::interval(ping_interval)); - let stopped = stop_handle.shutdown(); - - tokio::pin!(ping_interval, stopped); - - let next_ping = ping_interval.next(); - let mut futs = future::select(next_ping, stopped); - - loop { - // Ensure select is cancel-safe by fetching and storing the `rx_item` that did not finish - // yet. Note: Although, this is cancel-safe already, avoid using `select!` macro for - // future proofing. - match future::select(rx_item, futs).await { - // Received message. - Either::Left((Some(response), not_ready)) => { - // If websocket message send fail then terminate the connection. - if let Err(err) = send_message(&mut ws_sender, response).await { - tracing::error!("WS transport error: send failed: {}", err); - break; - } - rx_item = rx.next(); - futs = not_ready; - } - - // Nothing else to receive. - Either::Left((None, _)) => { - break; - } - - // Handle timer intervals. - Either::Right((Either::Left((_, stop)), next_rx)) => { - if let Err(err) = send_ping(&mut ws_sender).await { - tracing::error!("WS transport error: send ping failed: {}", err); - break; - } - rx_item = next_rx; - futs = future::select(ping_interval.next(), stop); - } - - // Server is closed - Either::Right((Either::Right((_, _)), _)) => { - break; - } - } - } - - // Terminate connection and send close message. 
- let _ = ws_sender.close().await; -} From 9411eb5d6a092e029146d12456555bc898561fa1 Mon Sep 17 00:00:00 2001 From: Fabricio Robles Date: Sat, 7 Dec 2024 10:09:13 -0600 Subject: [PATCH 12/21] Update cargo --- Cargo.lock | 102 +++++++++++++++++++++++++++++------------------------ 1 file changed, 56 insertions(+), 46 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 218bda9505..2ba5e7038f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -40,7 +40,7 @@ dependencies = [ "serde-wasm-bindgen", "serde_cbor_2", "serde_json", - "serde_with 3.9.0", + "serde_with 3.11.0", "sha2 0.10.8", "starknet 0.12.0", "starknet-crypto 0.7.2", @@ -1603,9 +1603,9 @@ checksum = "8b75356056920673b02621b35afd0f7dda9306d03c79a30f5c56c44cf256e3de" [[package]] name = "async-trait" -version = "0.1.82" +version = "0.1.83" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a27b8a3a6e1a44fa4c8baf1f653e4172e81486d4941f2237e20dc2d0cf4ddff1" +checksum = "721cae7de5c34fbb2acd27e21e6d2cf7b886dce0c27388d46c4e6c47ea4318dd" dependencies = [ "proc-macro2", "quote", @@ -1943,7 +1943,7 @@ dependencies = [ "bitflags 2.6.0", "cexpr", "clang-sys", - "itertools 0.10.5", + "itertools 0.12.1", "lazy_static", "lazycell", "log", @@ -2711,7 +2711,7 @@ dependencies = [ [[package]] name = "cairo-lang-macro" version = "0.1.0" -source = "git+https://github.com/dojoengine/scarb?branch=dojo-284#5ee01a699da7a973c38ba51eac1cb6065bb5006f" +source = "git+https://github.com/dojoengine/scarb?branch=dojo-284#7eac49b3e61236ce466e712225d9c989f9db1ef3" dependencies = [ "cairo-lang-macro-attributes", "cairo-lang-macro-stable", @@ -3694,7 +3694,7 @@ dependencies = [ "ed25519-dalek", "serde", "serde_json", - "serde_with 3.9.0", + "serde_with 3.11.0", "starknet-types-core", ] @@ -3958,7 +3958,7 @@ dependencies = [ [[package]] name = "create-output-dir" version = "1.0.0" -source = "git+https://github.com/dojoengine/scarb?branch=dojo-284#5ee01a699da7a973c38ba51eac1cb6065bb5006f" +source = "git+https://github.com/dojoengine/scarb?branch=dojo-284#7eac49b3e61236ce466e712225d9c989f9db1ef3" dependencies = [ "anyhow", "core-foundation 0.10.0", @@ -4099,9 +4099,9 @@ dependencies = [ [[package]] name = "csv" -version = "1.3.0" +version = "1.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ac574ff4d437a7b5ad237ef331c17ccca63c46479e5b5453eb8e10bb99a759fe" +checksum = "acdc4883a9c96732e4733212c01447ebd805833b7275a73ca3ee080fd77afdaf" dependencies = [ "csv-core", "itoa", @@ -4669,7 +4669,7 @@ dependencies = [ "semver 1.0.23", "serde", "serde_json", - "serde_with 3.9.0", + "serde_with 3.11.0", "smol_str", "starknet 0.12.0", "tempfile", @@ -4681,8 +4681,8 @@ dependencies = [ [[package]] name = "dojo-lang" -version = "1.0.0-rc.0" -source = "git+https://github.com/dojoengine/dojo?rev=479b698def87b004ecc074058182fac40d53d077#479b698def87b004ecc074058182fac40d53d077" +version = "1.0.0-rc.2" +source = "git+https://github.com/dojoengine/dojo?rev=6725a8f20af56213fa7382aa1e158817f3ee623c#6725a8f20af56213fa7382aa1e158817f3ee623c" dependencies = [ "anyhow", "cairo-lang-compiler", @@ -4700,16 +4700,17 @@ dependencies = [ "cairo-lang-utils", "camino", "convert_case 0.6.0", - "dojo-types 1.0.0-rc.0 (git+https://github.com/dojoengine/dojo?rev=479b698def87b004ecc074058182fac40d53d077)", + "dojo-types 1.0.0-rc.2", "indoc 1.0.9", "itertools 0.12.1", "regex", "semver 1.0.23", "serde", "serde_json", - "serde_with 3.9.0", + "serde_with 3.11.0", "smol_str", "starknet 0.12.0", + "starknet-crypto 0.7.2", "tempfile", "toml 0.8.19", "tracing", 
@@ -4791,8 +4792,8 @@ dependencies = [ [[package]] name = "dojo-types" -version = "1.0.0-rc.0" -source = "git+https://github.com/dojoengine/dojo?rev=479b698def87b004ecc074058182fac40d53d077#479b698def87b004ecc074058182fac40d53d077" +version = "1.0.0-rc.2" +source = "git+https://github.com/dojoengine/dojo?rev=6725a8f20af56213fa7382aa1e158817f3ee623c#6725a8f20af56213fa7382aa1e158817f3ee623c" dependencies = [ "anyhow", "cainome 0.4.6", @@ -4841,7 +4842,7 @@ dependencies = [ "regex", "serde", "serde_json", - "serde_with 3.9.0", + "serde_with 3.11.0", "starknet 0.12.0", "starknet-crypto 0.7.2", "thiserror", @@ -7208,7 +7209,7 @@ dependencies = [ "iana-time-zone-haiku", "js-sys", "wasm-bindgen", - "windows-core 0.51.1", + "windows-core 0.52.0", ] [[package]] @@ -8331,7 +8332,7 @@ dependencies = [ "rstest 0.18.2", "serde", "serde_json", - "serde_with 3.9.0", + "serde_with 3.11.0", "similar-asserts", "starknet 0.12.0", "starknet-crypto 0.7.2", @@ -8447,7 +8448,7 @@ dependencies = [ "rstest 0.18.2", "serde", "serde_json", - "serde_with 3.9.0", + "serde_with 3.11.0", "starknet 0.12.0", "thiserror", ] @@ -8610,9 +8611,9 @@ dependencies = [ [[package]] name = "lambdaworks-crypto" -version = "0.7.0" +version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7fb5d4f22241504f7c7b8d2c3a7d7835d7c07117f10bff2a7d96a9ef6ef217c3" +checksum = "bbc2a4da0d9e52ccfe6306801a112e81a8fc0c76aa3e4449fefeda7fef72bb34" dependencies = [ "lambdaworks-math", "serde", @@ -8622,9 +8623,9 @@ dependencies = [ [[package]] name = "lambdaworks-math" -version = "0.7.0" +version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "358e172628e713b80a530a59654154bfc45783a6ed70ea284839800cebdf8f97" +checksum = "d1bd2632acbd9957afc5aeec07ad39f078ae38656654043bf16e046fa2730e23" dependencies = [ "serde", "serde_json", @@ -11043,7 +11044,7 @@ checksum = "22505a5c94da8e3b7c2996394d1c933236c4d743e81a410bcca4e6989fc066a4" dependencies = [ "bytes", "heck 0.5.0", - "itertools 0.10.5", + "itertools 0.12.1", "log", "multimap", "once_cell", @@ -11084,7 +11085,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "81bddcdb20abf9501610992b6759a4c888aef7d1a7247ef75e2404275ac24af1" dependencies = [ "anyhow", - "itertools 0.10.5", + "itertools 0.12.1", "proc-macro2", "quote", "syn 2.0.77", @@ -12440,7 +12441,7 @@ dependencies = [ "num-traits 0.2.19", "serde", "serde_json", - "serde_with 3.9.0", + "serde_with 3.11.0", "starknet 0.12.0", "thiserror", "tokio", @@ -12451,7 +12452,7 @@ dependencies = [ [[package]] name = "scarb" version = "2.8.4" -source = "git+https://github.com/dojoengine/scarb?branch=dojo-284#5ee01a699da7a973c38ba51eac1cb6065bb5006f" +source = "git+https://github.com/dojoengine/scarb?branch=dojo-284#7eac49b3e61236ce466e712225d9c989f9db1ef3" dependencies = [ "anyhow", "async-trait", @@ -12481,7 +12482,7 @@ dependencies = [ "derive_builder", "dialoguer", "directories", - "dojo-lang 1.0.0-rc.0 (git+https://github.com/dojoengine/dojo?rev=479b698def87b004ecc074058182fac40d53d077)", + "dojo-lang 1.0.0-rc.2", "dunce", "fs4", "fs_extra", @@ -12532,7 +12533,7 @@ dependencies = [ [[package]] name = "scarb-build-metadata" version = "2.8.4" -source = "git+https://github.com/dojoengine/scarb?branch=dojo-284#5ee01a699da7a973c38ba51eac1cb6065bb5006f" +source = "git+https://github.com/dojoengine/scarb?branch=dojo-284#7eac49b3e61236ce466e712225d9c989f9db1ef3" dependencies = [ "cargo_metadata", ] @@ -12553,7 +12554,7 @@ dependencies = [ [[package]] name = 
"scarb-metadata" version = "1.12.0" -source = "git+https://github.com/dojoengine/scarb?branch=dojo-284#5ee01a699da7a973c38ba51eac1cb6065bb5006f" +source = "git+https://github.com/dojoengine/scarb?branch=dojo-284#7eac49b3e61236ce466e712225d9c989f9db1ef3" dependencies = [ "camino", "derive_builder", @@ -12576,7 +12577,7 @@ dependencies = [ [[package]] name = "scarb-stable-hash" version = "1.0.0" -source = "git+https://github.com/dojoengine/scarb?branch=dojo-284#5ee01a699da7a973c38ba51eac1cb6065bb5006f" +source = "git+https://github.com/dojoengine/scarb?branch=dojo-284#7eac49b3e61236ce466e712225d9c989f9db1ef3" dependencies = [ "data-encoding", "xxhash-rust", @@ -12585,7 +12586,7 @@ dependencies = [ [[package]] name = "scarb-ui" version = "0.1.5" -source = "git+https://github.com/dojoengine/scarb?branch=dojo-284#5ee01a699da7a973c38ba51eac1cb6065bb5006f" +source = "git+https://github.com/dojoengine/scarb?branch=dojo-284#7eac49b3e61236ce466e712225d9c989f9db1ef3" dependencies = [ "anyhow", "camino", @@ -12873,9 +12874,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.128" +version = "1.0.133" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ff5456707a1de34e7e37f2a6fd3d3f808c318259cbd01ab6377795054b483d8" +checksum = "c7fceb2473b9166b2294ef05efcb65a3db80803f0b03ef86a5fc88a2b85ee377" dependencies = [ "itoa", "memchr", @@ -12954,9 +12955,9 @@ dependencies = [ [[package]] name = "serde_with" -version = "3.9.0" +version = "3.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69cecfa94848272156ea67b2b1a53f20fc7bc638c4a46d2f8abde08f05f4b857" +checksum = "8e28bdad6db2b8340e449f7108f020b3b092e8583a9e3fb82713e1d4e71fe817" dependencies = [ "base64 0.22.1", "chrono", @@ -12966,7 +12967,7 @@ dependencies = [ "serde", "serde_derive", "serde_json", - "serde_with_macros 3.9.0", + "serde_with_macros 3.11.0", "time", ] @@ -12984,9 +12985,9 @@ dependencies = [ [[package]] name = "serde_with_macros" -version = "3.9.0" +version = "3.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a8fee4991ef4f274617a51ad4af30519438dacb2f56ac773b08a1922ff743350" +checksum = "9d846214a9854ef724f3da161b426242d8de7c1fc7de2f89bb1efcb154dca79d" dependencies = [ "darling 0.20.10", "proc-macro2", @@ -13428,7 +13429,7 @@ dependencies = [ "num-traits 0.2.19", "serde", "serde_json", - "serde_with 3.9.0", + "serde_with 3.11.0", "sozo-walnut", "spinoff", "starknet 0.12.0", @@ -13831,7 +13832,7 @@ checksum = "bd6ee5762d24c4f06ab7e9406550925df406712e73719bd2de905c879c674a87" dependencies = [ "serde", "serde_json", - "serde_with 3.9.0", + "serde_with 3.11.0", "starknet-accounts 0.11.0", "starknet-core 0.12.0", "starknet-providers 0.12.0", @@ -13870,7 +13871,7 @@ dependencies = [ "serde", "serde_json", "serde_json_pythonic", - "serde_with 3.9.0", + "serde_with 3.11.0", "sha3", "starknet-crypto 0.7.2", "starknet-types-core", @@ -14033,7 +14034,7 @@ dependencies = [ "reqwest 0.11.27", "serde", "serde_json", - "serde_with 3.9.0", + "serde_with 3.11.0", "starknet-core 0.12.0", "thiserror", "url", @@ -14075,9 +14076,9 @@ dependencies = [ [[package]] name = "starknet-types-core" -version = "0.1.6" +version = "0.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9b889ee5734db8b3c8a6551135c16764bf4ce1ab4955fffbb2ac5b6706542b64" +checksum = "fa1b9e01ccb217ab6d475c5cda05dbb22c30029f7bb52b192a010a00d77a3d74" dependencies = [ "arbitrary", "lambdaworks-crypto", @@ -16536,6 +16537,15 @@ dependencies = [ 
"windows-targets 0.48.5", ] +[[package]] +name = "windows-core" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9" +dependencies = [ + "windows-targets 0.52.6", +] + [[package]] name = "windows-core" version = "0.57.0" From 8e61aa4521f5f1ec4e787b19204584fb2da7e851 Mon Sep 17 00:00:00 2001 From: Fabricio Robles Date: Thu, 12 Dec 2024 20:18:11 -0600 Subject: [PATCH 13/21] - Accepting unit param - Path is not necesary now when calling the DevnetProxyLayer - Using a match to return back the params and the method to call through rpc call --- Cargo.lock | 27 ---- crates/katana/node/src/lib.rs | 3 +- crates/katana/primitives/src/fee.rs | 13 ++ .../katana/primitives/src/genesis/constant.rs | 11 ++ crates/katana/rpc/rpc-api/src/dev.rs | 7 +- crates/katana/rpc/rpc/src/dev.rs | 31 ++--- .../katana/rpc/rpc/src/proxy_get_request.rs | 116 +++++++----------- 7 files changed, 83 insertions(+), 125 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 62f1cd4b74..78066a3b2b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4324,27 +4324,6 @@ dependencies = [ "typenum", ] -[[package]] -name = "csv" -version = "1.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "acdc4883a9c96732e4733212c01447ebd805833b7275a73ca3ee080fd77afdaf" -dependencies = [ - "csv-core", - "itoa", - "ryu", - "serde", -] - -[[package]] -name = "csv-core" -version = "0.1.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5efa2b3d7902f4b634a20cae3c9c4e6209dc4779feb6863329607560143efa70" -dependencies = [ - "memchr", -] - [[package]] name = "ctr" version = "0.9.2" @@ -4920,14 +4899,8 @@ dependencies = [ "cairo-lang-syntax", "cairo-lang-utils", "dojo-types 1.0.5", - "camino", - "convert_case 0.6.0", - "dojo-types 1.0.0-rc.2", - "indoc 1.0.9", "itertools 0.12.1", "serde", - "serde_json", - "serde_with 3.11.0", "smol_str", "starknet 0.12.0", "starknet-crypto 0.7.2", diff --git a/crates/katana/node/src/lib.rs b/crates/katana/node/src/lib.rs index bbb88b64fb..fa6bb4943a 100644 --- a/crates/katana/node/src/lib.rs +++ b/crates/katana/node/src/lib.rs @@ -325,8 +325,7 @@ pub async fn spawn( let middleware = tower::ServiceBuilder::new() .option_layer(cors) .layer(ProxyGetRequestLayer::new("/", "health")?) - .layer(DevnetProxyLayer::new("/account_balance", "dev_accountBalance")?) - .layer(DevnetProxyLayer::new("/fee_token", "dev_feeToken")?) + .layer(DevnetProxyLayer::new()?) .timeout(Duration::from_secs(20)); let server = ServerBuilder::new() diff --git a/crates/katana/primitives/src/fee.rs b/crates/katana/primitives/src/fee.rs index 11c617da2d..ed332b8944 100644 --- a/crates/katana/primitives/src/fee.rs +++ b/crates/katana/primitives/src/fee.rs @@ -1,3 +1,5 @@ +use std::str::FromStr; + #[derive(Debug, Clone, PartialEq, Eq, Default)] #[cfg_attr(feature = "arbitrary", derive(::arbitrary::Arbitrary))] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] @@ -29,6 +31,17 @@ pub enum PriceUnit { Fri, } +impl FromStr for PriceUnit { + type Err = (); + fn from_str(s: &str) -> Result { + match s { + "WEI" => Ok(PriceUnit::Wei), + "FRI" => Ok(PriceUnit::Fri), + _ => Err(()), // Return an error for unknown units + } + } +} + /// Information regarding the fee and gas usages of a transaction. 
#[derive(Debug, Clone, PartialEq, Eq)] #[cfg_attr(feature = "arbitrary", derive(::arbitrary::Arbitrary))] diff --git a/crates/katana/primitives/src/genesis/constant.rs b/crates/katana/primitives/src/genesis/constant.rs index 44d19b88e7..c4add00537 100644 --- a/crates/katana/primitives/src/genesis/constant.rs +++ b/crates/katana/primitives/src/genesis/constant.rs @@ -1,9 +1,11 @@ +use anyhow::Error; use lazy_static::lazy_static; use starknet::core::utils::get_storage_var_address; use starknet::macros::felt; use crate::class::{ClassHash, CompiledClass, CompiledClassHash, ContractClass}; use crate::contract::{ContractAddress, StorageKey}; +use crate::fee::PriceUnit; use crate::utils::class::{ parse_compiled_class, parse_deprecated_compiled_class, parse_sierra_class, }; @@ -124,6 +126,15 @@ fn read_legacy_class_artifact(artifact: &str) -> ContractClass { ContractClass::Legacy(class) } +pub fn get_erc20_address(unit: &PriceUnit) -> Result { + let erc20_contract_address = match unit { + PriceUnit::Wei => DEFAULT_ETH_FEE_TOKEN_ADDRESS, + PriceUnit::Fri => DEFAULT_STRK_FEE_TOKEN_ADDRESS, + }; + + Ok(ContractAddress::new(erc20_contract_address.into())) +} + #[cfg(test)] mod tests { diff --git a/crates/katana/rpc/rpc-api/src/dev.rs b/crates/katana/rpc/rpc-api/src/dev.rs index ad9937f453..49ce1f8d39 100644 --- a/crates/katana/rpc/rpc-api/src/dev.rs +++ b/crates/katana/rpc/rpc-api/src/dev.rs @@ -20,16 +20,13 @@ pub trait DevApi { #[method(name = "setStorageAt")] async fn set_storage_at(&self, contract_address: Felt, key: Felt, value: Felt) - -> RpcResult<()>; + -> RpcResult<()>; #[method(name = "predeployedAccounts")] async fn predeployed_accounts(&self) -> RpcResult>; #[method(name = "accountBalance")] - async fn account_balance(&self, address: String) -> RpcResult; - - #[method(name = "feeToken")] - async fn fee_token(&self) -> RpcResult; + async fn account_balance(&self, address: String, unit: String) -> RpcResult; #[method(name = "mint")] async fn mint(&self) -> RpcResult<()>; diff --git a/crates/katana/rpc/rpc/src/dev.rs b/crates/katana/rpc/rpc/src/dev.rs index 6eb17474bf..c3c16c246d 100644 --- a/crates/katana/rpc/rpc/src/dev.rs +++ b/crates/katana/rpc/rpc/src/dev.rs @@ -1,12 +1,14 @@ use std::str::FromStr; use std::sync::Arc; +use crate::transport::http; use jsonrpsee::core::{async_trait, Error}; use katana_core::backend::Backend; use katana_core::service::block_producer::{BlockProducer, BlockProducerMode, PendingExecutor}; use katana_executor::ExecutorFactory; +use katana_primitives::fee::PriceUnit; use katana_primitives::genesis::constant::{ - get_fee_token_balance_base_storage_address, ERC20_NAME_STORAGE_SLOT, + get_erc20_address, get_fee_token_balance_base_storage_address, }; use katana_primitives::ContractAddress; use katana_provider::traits::state::StateFactoryProvider; @@ -97,31 +99,22 @@ impl DevApiServer for DevApi { Ok(()) } - async fn account_balance(&self, address: String) -> Result { - let account_address: ContractAddress = Felt::from_str(&address).unwrap().into(); + async fn account_balance(&self, address: String, unit: String) -> Result { + let account_address: ContractAddress = Felt::from_str(address.as_str()).unwrap().into(); + let unit = Some(PriceUnit::from_str(unit.to_uppercase().as_str())) + .unwrap() + .unwrap_or(PriceUnit::Wei); + let erc20_address = + get_erc20_address(&unit).map_err(|_| http::response::internal_error()).unwrap(); + let provider = self.backend.blockchain.provider(); let state = provider.latest().unwrap(); - // let storage_slot = - // 
get_storage_var_address("ERC20_balances", &[account_address.into()]).unwrap(); let storage_slot = get_fee_token_balance_base_storage_address(account_address); - let balance_felt = state - .storage(self.backend.chain_spec.fee_contracts.eth, storage_slot) - .unwrap() - .unwrap(); + let balance_felt = state.storage(erc20_address, storage_slot).unwrap().unwrap(); let balance: u128 = balance_felt.to_string().parse().unwrap(); Ok(balance) } - async fn fee_token(&self) -> Result { - let provider = self.backend.blockchain.provider(); - let state = provider.latest().unwrap(); - let fee_token = state - .storage(self.backend.chain_spec.fee_contracts.eth, ERC20_NAME_STORAGE_SLOT) - .unwrap() - .unwrap(); - Ok(fee_token.to_string()) - } - async fn mint(&self) -> Result<(), Error> { Ok(()) } diff --git a/crates/katana/rpc/rpc/src/proxy_get_request.rs b/crates/katana/rpc/rpc/src/proxy_get_request.rs index 210e7526cb..26d7ca263a 100644 --- a/crates/katana/rpc/rpc/src/proxy_get_request.rs +++ b/crates/katana/rpc/rpc/src/proxy_get_request.rs @@ -1,19 +1,16 @@ //! Middleware that proxies requests at a specified URI to internal //! RPC method calls. - -use std::collections::HashMap; -use std::error::Error; -use std::future::Future; -use std::pin::Pin; -use std::sync::Arc; -use std::task::{Context, Poll}; - use hyper::header::{ACCEPT, CONTENT_TYPE}; use hyper::http::HeaderValue; use hyper::{Body, Method, Request, Response, Uri}; use jsonrpsee_core::error::Error as RpcError; use jsonrpsee_core::JsonRawValue; use jsonrpsee_types::{Id, RequestSer}; +use std::collections::HashMap; +use std::error::Error; +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; use tower::{Layer, Service}; use url::form_urlencoded; @@ -24,30 +21,21 @@ use crate::transport::http; /// /// See [`DevnetProxy`] for more details. #[derive(Debug, Clone)] -pub struct DevnetProxyLayer { - path: String, - method: String, -} +pub struct DevnetProxyLayer {} impl DevnetProxyLayer { /// Creates a new [`DevnetProxyLayer`]. /// /// See [`DevnetProxy`] for more details. - pub fn new(path: impl Into, method: impl Into) -> Result { - let path = path.into(); - if !path.starts_with('/') { - return Err(RpcError::Custom("DevnetProxyLayer path must start with `/`".to_string())); - } - - Ok(Self { path, method: method.into() }) + pub fn new() -> Result { + Ok(Self {}) } } impl Layer for DevnetProxyLayer { type Service = DevnetProxy; fn layer(&self, inner: S) -> Self::Service { - DevnetProxy::new(inner, &self.path, &self.method) - .expect("Path already validated in DevnetProxyLayer; qed") + DevnetProxy::new(inner).expect("Path already validated in DevnetProxyLayer; qed") } } @@ -66,8 +54,6 @@ impl Layer for DevnetProxyLayer { #[derive(Debug, Clone)] pub struct DevnetProxy { inner: S, - path: Arc, - method: Arc, } impl DevnetProxy { @@ -75,15 +61,8 @@ impl DevnetProxy { /// /// The request `GET /path` is redirected to the provided method. /// Fails if the path does not start with `/`. 
- pub fn new(inner: S, path: &str, method: &str) -> Result { - if !path.starts_with('/') { - return Err(RpcError::Custom(format!( - "DevnetProxy path must start with `/`, got: {}", - path - ))); - } - - Ok(Self { inner, path: Arc::from(path), method: Arc::from(method) }) + pub fn new(inner: S) -> Result { + Ok(Self { inner }) } } @@ -105,43 +84,29 @@ where } fn call(&mut self, mut req: Request) -> Self::Future { - // let modify = self.path.as_ref() == req.uri() && req.method() == Method::GET; - let modify = req.method() == Method::GET; - - // Proxy the request to the appropriate method call. - if modify { - let mut raw_value = None; - - // If method is dev_accountBalance then get the contract_address query param and assign - // it to raw_value - if self.method.to_string() == "dev_accountBalance".to_string() { - if let Some(query) = req.uri().query() { - let params: HashMap<_, _> = - form_urlencoded::parse(query.as_bytes()).into_owned().collect(); - if let Some(address) = params.get("contract_address") { - let json_string = format!(r#"{{"address":"{}"}}"#, address); - raw_value = Some(JsonRawValue::from_string(json_string).unwrap()); - } - } - } + let query = req.uri().query(); + let path = req.uri().path(); + + let (params, method) = match path { + "/account_balance" => get_account_balance(query), + _ => (JsonRawValue::from_string(String::new()).unwrap(), "".to_string()), + }; - // RPC methods are accessed with `POST`. - *req.method_mut() = Method::POST; - // Precautionary remove the URI. - *req.uri_mut() = Uri::from_static("/"); + // RPC methods are accessed with `POST`. + *req.method_mut() = Method::POST; + // Precautionary remove the URI. + *req.uri_mut() = Uri::from_static("/"); - // Requests must have the following headers: - req.headers_mut().insert(CONTENT_TYPE, HeaderValue::from_static("application/json")); - req.headers_mut().insert(ACCEPT, HeaderValue::from_static("application/json")); + // Requests must have the following headers: + req.headers_mut().insert(CONTENT_TYPE, HeaderValue::from_static("application/json")); + req.headers_mut().insert(ACCEPT, HeaderValue::from_static("application/json")); - // Adjust the body to reflect the method call. - let param = raw_value.as_ref().map(|value| value.as_ref()); - let body = Body::from( - serde_json::to_string(&RequestSer::borrowed(&Id::Number(0), &self.method, param)) - .expect("Valid request; qed"), - ); - req = req.map(|_| body); - } + // Adjust the body to reflect the method call. + let body = Body::from( + serde_json::to_string(&RequestSer::borrowed(&Id::Number(0), &method, Some(¶ms))) + .expect("Valid request; qed"), + ); + req = req.map(|_| body); // Call the inner service and get a future that resolves to the response. let fut = self.inner.call(req); @@ -149,12 +114,6 @@ where // Adjust the response if needed. let res_fut = async move { let res = fut.await.map_err(|err| err.into())?; - - // Nothing to modify: return the response as is. 
- if !modify { - return Ok(res); - } - let body = res.into_body(); let bytes = hyper::body::to_bytes(body).await?; @@ -176,3 +135,16 @@ where Box::pin(res_fut) } } + +fn get_account_balance(query: Option<&str>) -> (Box, std::string::String) { + let default = String::new(); + + let query = query.unwrap_or(&default); + let params: HashMap<_, _> = form_urlencoded::parse(query.as_bytes()).into_owned().collect(); + + let address = params.get("contract_address").unwrap_or(&default); + let unit = params.get("unit").unwrap_or(&default); + + let json_string = format!(r#"{{"address":"{}", "unit":"{}"}}"#, address, unit); + (JsonRawValue::from_string(json_string).unwrap(), "dev_accountBalance".to_string()) +} From 01f69c2f1e9dbe571861f9b87b676ce29a718afe Mon Sep 17 00:00:00 2001 From: Fabricio Robles Date: Thu, 12 Dec 2024 20:52:39 -0600 Subject: [PATCH 14/21] Fix cargo, removing duplicates --- crates/katana/rpc/rpc/Cargo.toml | 3 --- 1 file changed, 3 deletions(-) diff --git a/crates/katana/rpc/rpc/Cargo.toml b/crates/katana/rpc/rpc/Cargo.toml index 3912fe78cb..8226d631dd 100644 --- a/crates/katana/rpc/rpc/Cargo.toml +++ b/crates/katana/rpc/rpc/Cargo.toml @@ -15,8 +15,6 @@ jsonrpsee = { workspace = true, features = [ "server" ] } jsonrpsee-core = { version = "0.16.3", features = [ "server", "soketto", "http-helpers" ] } jsonrpsee-types = { version = "0.16.3"} hyper.workspace = true -tower = { workspace = true, features = [ "full" ] } -http = { version = "0.2.7" } katana-core.workspace = true katana-executor.workspace = true katana-pool.workspace = true @@ -36,7 +34,6 @@ tower-http.workspace = true tracing.workspace = true url.workspace = true serde.workspace = true -serde_json.workspace = true soketto = { version = "0.7.1", features = ["http"] } futures-channel = { version = "0.3.14"} futures-util = { version = "0.3.14", features = [ From 78c21801689b31dec8357e1228bf96339d7b5967 Mon Sep 17 00:00:00 2001 From: Fabricio Robles Date: Thu, 12 Dec 2024 20:55:37 -0600 Subject: [PATCH 15/21] Implement Devnet Proxy Layer in new implementation --- crates/katana/node/src/lib.rs | 2 -- crates/katana/rpc/rpc/src/lib.rs | 2 ++ 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/katana/node/src/lib.rs b/crates/katana/node/src/lib.rs index 9eeb3c1bb5..922586c8da 100644 --- a/crates/katana/node/src/lib.rs +++ b/crates/katana/node/src/lib.rs @@ -35,8 +35,6 @@ use katana_primitives::block::GasPrices; use katana_primitives::env::{CfgEnv, FeeTokenAddressses}; use katana_rpc::cors::Cors; use katana_rpc::dev::DevApi; -use katana_rpc::metrics::RpcServerMetrics; -use katana_rpc::proxy_get_request::DevnetProxyLayer; use katana_rpc::saya::SayaApi; use katana_rpc::starknet::forking::ForkedClient; use katana_rpc::starknet::{StarknetApi, StarknetApiConfig}; diff --git a/crates/katana/rpc/rpc/src/lib.rs b/crates/katana/rpc/rpc/src/lib.rs index 0bf8213da9..5b34daf128 100644 --- a/crates/katana/rpc/rpc/src/lib.rs +++ b/crates/katana/rpc/rpc/src/lib.rs @@ -8,6 +8,7 @@ use std::time::Duration; use jsonrpsee::server::{AllowHosts, ServerBuilder, ServerHandle}; use jsonrpsee::RpcModule; +use proxy_get_request::DevnetProxyLayer; use tower::ServiceBuilder; use tracing::info; @@ -121,6 +122,7 @@ impl RpcServer { let middleware = ServiceBuilder::new() .option_layer(self.cors.clone()) .option_layer(health_check_proxy) + .layer(DevnetProxyLayer::new()?) 
.timeout(Duration::from_secs(20)); let builder = ServerBuilder::new() From c8d2444df6223052c98ee15473eb332afa1b3d96 Mon Sep 17 00:00:00 2001 From: Fabricio Robles Date: Fri, 3 Jan 2025 17:51:23 -0600 Subject: [PATCH 16/21] Clean unused code --- crates/katana/rpc/rpc/src/future.rs | 216 --------- crates/katana/rpc/rpc/src/logger.rs | 193 -------- crates/katana/rpc/rpc/src/transport/http.rs | 466 +------------------- 3 files changed, 1 insertion(+), 874 deletions(-) delete mode 100644 crates/katana/rpc/rpc/src/future.rs delete mode 100644 crates/katana/rpc/rpc/src/logger.rs diff --git a/crates/katana/rpc/rpc/src/future.rs b/crates/katana/rpc/rpc/src/future.rs deleted file mode 100644 index 82c28f1ebc..0000000000 --- a/crates/katana/rpc/rpc/src/future.rs +++ /dev/null @@ -1,216 +0,0 @@ -// Copyright 2019-2021 Parity Technologies (UK) Ltd. -// -// Permission is hereby granted, free of charge, to any -// person obtaining a copy of this software and associated -// documentation files (the "Software"), to deal in the -// Software without restriction, including without -// limitation the rights to use, copy, modify, merge, -// publish, distribute, sublicense, and/or sell copies of -// the Software, and to permit persons to whom the Software -// is furnished to do so, subject to the following -// conditions: -// -// The above copyright notice and this permission notice -// shall be included in all copies or substantial portions -// of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF -// ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED -// TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A -// PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT -// SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY -// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION -// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR -// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -// DEALINGS IN THE SOFTWARE. - -//! Utilities for handling async code. - -use std::future::Future; -use std::pin::Pin; -use std::sync::Arc; -use std::task::{Context, Poll}; - -use futures_util::future::FutureExt; -use jsonrpsee_core::Error; -use tokio::sync::{watch, OwnedSemaphorePermit, Semaphore, TryAcquireError}; -use tokio::time::{self, Duration, Interval}; - -/// Polling for server stop monitor interval in milliseconds. -const STOP_MONITOR_POLLING_INTERVAL: Duration = Duration::from_millis(1000); - -/// This is a flexible collection of futures that need to be driven to completion -/// alongside some other future, such as connection handlers that need to be -/// handled along with a listener for new connections. -/// -/// In order to `.await` on these futures and drive them to completion, call -/// `select_with` providing some other future, the result of which you need. 
-pub(crate) struct FutureDriver { - futures: Vec, - stop_monitor_heartbeat: Interval, -} - -impl Default for FutureDriver { - fn default() -> Self { - let mut heartbeat = time::interval(STOP_MONITOR_POLLING_INTERVAL); - - heartbeat.set_missed_tick_behavior(time::MissedTickBehavior::Skip); - - FutureDriver { futures: Vec::new(), stop_monitor_heartbeat: heartbeat } - } -} - -impl FutureDriver { - /// Add a new future to this driver - pub(crate) fn add(&mut self, future: F) { - self.futures.push(future); - } -} - -impl FutureDriver -where - F: Future + Unpin, -{ - pub(crate) async fn select_with(&mut self, selector: S) -> S::Output { - tokio::pin!(selector); - - DriverSelect { selector, driver: self }.await - } - - fn drive(&mut self, cx: &mut Context) { - let mut i = 0; - - while i < self.futures.len() { - if self.futures[i].poll_unpin(cx).is_ready() { - // Using `swap_remove` since we don't care about ordering - // but we do care about removing being `O(1)`. - // - // We don't increment `i` in this branch, since we now - // have a shorter length, and potentially a new value at - // current index - self.futures.swap_remove(i); - } else { - i += 1; - } - } - } - - fn poll_stop_monitor_heartbeat(&mut self, cx: &mut Context) { - // We don't care about the ticks of the heartbeat, it's here only - // to periodically wake the `Waker` on `cx`. - let _ = self.stop_monitor_heartbeat.poll_tick(cx); - } -} - -impl Future for FutureDriver -where - F: Future + Unpin, -{ - type Output = (); - - fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { - let this = Pin::into_inner(self); - - this.drive(cx); - - if this.futures.is_empty() { Poll::Ready(()) } else { Poll::Pending } - } -} - -/// This is a glorified select `Future` that will attempt to drive all -/// connection futures `F` to completion on each `poll`, while also -/// handling incoming connections. -struct DriverSelect<'a, S, F> { - selector: S, - driver: &'a mut FutureDriver, -} - -impl<'a, R, F> Future for DriverSelect<'a, R, F> -where - R: Future + Unpin, - F: Future + Unpin, -{ - type Output = R::Output; - - fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { - let this = Pin::into_inner(self); - - this.driver.drive(cx); - this.driver.poll_stop_monitor_heartbeat(cx); - - this.selector.poll_unpin(cx) - } -} - -#[derive(Debug, Clone)] -pub(crate) struct StopHandle(watch::Receiver<()>); - -impl StopHandle { - pub(crate) fn new(rx: watch::Receiver<()>) -> Self { - Self(rx) - } - - pub(crate) fn shutdown_requested(&self) -> bool { - // if a message has been seen, it means that `stop` has been called. - self.0.has_changed().unwrap_or(true) - } - - pub(crate) async fn shutdown(&mut self) { - // Err(_) implies that the `sender` has been dropped. - // Ok(_) implies that `stop` has been called. - let _ = self.0.changed().await; - } -} - -/// Server handle. -/// -/// When all [`StopHandle`]'s have been `dropped` or `stop` has been called -/// the server will be stopped. -#[derive(Debug, Clone)] -pub struct ServerHandle(Arc>); - -impl ServerHandle { - /// Create a new server handle. - pub fn new(tx: watch::Sender<()>) -> Self { - Self(Arc::new(tx)) - } - - /// Tell the server to stop without waiting for the server to stop. - pub fn stop(&self) -> Result<(), Error> { - self.0.send(()).map_err(|_| Error::AlreadyStopped) - } - - /// Wait for the server to stop. - pub async fn stopped(self) { - self.0.closed().await - } - - /// Check if the server has been stopped. 
- pub fn is_stopped(&self) -> bool { - self.0.is_closed() - } -} - -/// Limits the number of connections. -#[derive(Debug)] -pub(crate) struct ConnectionGuard(Arc); - -impl ConnectionGuard { - pub(crate) fn new(limit: usize) -> Self { - Self(Arc::new(Semaphore::new(limit))) - } - - pub(crate) fn try_acquire(&self) -> Option { - match self.0.clone().try_acquire_owned() { - Ok(guard) => Some(guard), - Err(TryAcquireError::Closed) => { - unreachable!("Semaphore::Close is never called and can't be closed; qed") - } - Err(TryAcquireError::NoPermits) => None, - } - } - - pub(crate) fn available_connections(&self) -> usize { - self.0.available_permits() - } -} diff --git a/crates/katana/rpc/rpc/src/logger.rs b/crates/katana/rpc/rpc/src/logger.rs deleted file mode 100644 index 72fcc48605..0000000000 --- a/crates/katana/rpc/rpc/src/logger.rs +++ /dev/null @@ -1,193 +0,0 @@ -// Copyright 2019-2021 Parity Technologies (UK) Ltd. -// -// Permission is hereby granted, free of charge, to any -// person obtaining a copy of this software and associated -// documentation files (the "Software"), to deal in the -// Software without restriction, including without -// limitation the rights to use, copy, modify, merge, -// publish, distribute, sublicense, and/or sell copies of -// the Software, and to permit persons to whom the Software -// is furnished to do so, subject to the following -// conditions: -// -// The above copyright notice and this permission notice -// shall be included in all copies or substantial portions -// of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF -// ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED -// TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A -// PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT -// SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY -// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION -// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR -// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -// DEALINGS IN THE SOFTWARE. - -//! Logger for `jsonrpsee` servers. - -use std::net::SocketAddr; - -/// HTTP request. -pub type HttpRequest = http::Request; -pub use hyper::Body; -pub use jsonrpsee_types::Params; - -/// The type JSON-RPC v2 call, it can be a subscription, method call or unknown. -#[derive(Debug, Copy, Clone)] -pub enum MethodKind { - /// Subscription Call. - Subscription, - /// Unsubscription Call. - Unsubscription, - /// Method call. - MethodCall, - /// Unknown method. - Unknown, -} - -impl std::fmt::Display for MethodKind { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let s = match self { - Self::Subscription => "subscription", - Self::MethodCall => "method call", - Self::Unknown => "unknown", - Self::Unsubscription => "unsubscription", - }; - - write!(f, "{}", s) - } -} - -/// The transport protocol used to send or receive a call or request. -#[derive(Debug, Copy, Clone)] -pub enum TransportProtocol { - /// HTTP transport. - Http, - /// WebSocket transport. - WebSocket, -} - -impl std::fmt::Display for TransportProtocol { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let s = match self { - Self::Http => "http", - Self::WebSocket => "websocket", - }; - - write!(f, "{}", s) - } -} - -/// Defines a logger specifically for WebSocket connections with callbacks during the RPC request -/// life-cycle. The primary use case for this is to collect timings for a larger metrics collection -/// solution. 
-/// -/// See the [`ServerBuilder::set_logger`](../../jsonrpsee_server/struct.ServerBuilder.html#method. -/// set_logger) for examples. -pub trait Logger: Send + Sync + Clone + 'static { - /// Intended to carry timestamp of a request, for example `std::time::Instant`. How the trait - /// measures time, if at all, is entirely up to the implementation. - type Instant: std::fmt::Debug + Send + Sync + Copy; - - /// Called when a new client connects - fn on_connect(&self, _remote_addr: SocketAddr, _request: &HttpRequest, _t: TransportProtocol); - - /// Called when a new JSON-RPC request comes to the server. - fn on_request(&self, transport: TransportProtocol) -> Self::Instant; - - /// Called on each JSON-RPC method call, batch requests will trigger `on_call` multiple times. - fn on_call( - &self, - method_name: &str, - params: Params, - kind: MethodKind, - transport: TransportProtocol, - ); - - /// Called on each JSON-RPC method completion, batch requests will trigger `on_result` multiple - /// times. - fn on_result( - &self, - method_name: &str, - success: bool, - started_at: Self::Instant, - transport: TransportProtocol, - ); - - /// Called once the JSON-RPC request is finished and response is sent to the output buffer. - fn on_response(&self, result: &str, started_at: Self::Instant, transport: TransportProtocol); - - /// Called when a client disconnects - fn on_disconnect(&self, _remote_addr: SocketAddr, transport: TransportProtocol); -} - -impl Logger for () { - type Instant = (); - - fn on_connect(&self, _: SocketAddr, _: &HttpRequest, _p: TransportProtocol) -> Self::Instant {} - - fn on_request(&self, _p: TransportProtocol) -> Self::Instant {} - - fn on_call(&self, _: &str, _: Params, _: MethodKind, _p: TransportProtocol) {} - - fn on_result(&self, _: &str, _: bool, _: Self::Instant, _p: TransportProtocol) {} - - fn on_response(&self, _: &str, _: Self::Instant, _p: TransportProtocol) {} - - fn on_disconnect(&self, _: SocketAddr, _p: TransportProtocol) {} -} - -impl Logger for (A, B) -where - A: Logger, - B: Logger, -{ - type Instant = (A::Instant, B::Instant); - - fn on_connect( - &self, - remote_addr: std::net::SocketAddr, - request: &HttpRequest, - transport: TransportProtocol, - ) { - self.0.on_connect(remote_addr, request, transport); - self.1.on_connect(remote_addr, request, transport); - } - - fn on_request(&self, transport: TransportProtocol) -> Self::Instant { - (self.0.on_request(transport), self.1.on_request(transport)) - } - - fn on_call( - &self, - method_name: &str, - params: Params, - kind: MethodKind, - transport: TransportProtocol, - ) { - self.0.on_call(method_name, params.clone(), kind, transport); - self.1.on_call(method_name, params, kind, transport); - } - - fn on_result( - &self, - method_name: &str, - success: bool, - started_at: Self::Instant, - transport: TransportProtocol, - ) { - self.0.on_result(method_name, success, started_at.0, transport); - self.1.on_result(method_name, success, started_at.1, transport); - } - - fn on_response(&self, result: &str, started_at: Self::Instant, transport: TransportProtocol) { - self.0.on_response(result, started_at.0, transport); - self.1.on_response(result, started_at.1, transport); - } - - fn on_disconnect(&self, remote_addr: SocketAddr, transport: TransportProtocol) { - self.0.on_disconnect(remote_addr, transport); - self.1.on_disconnect(remote_addr, transport); - } -} diff --git a/crates/katana/rpc/rpc/src/transport/http.rs b/crates/katana/rpc/rpc/src/transport/http.rs index a081222486..9bbfce3309 100644 --- 
a/crates/katana/rpc/rpc/src/transport/http.rs +++ b/crates/katana/rpc/rpc/src/transport/http.rs @@ -1,410 +1,8 @@ -use std::convert::Infallible; -use std::net::SocketAddr; -use std::sync::Arc; - -use futures_util::future::Either; -use futures_util::stream::{FuturesOrdered, StreamExt}; -use http::Method; -use jsonrpsee_core::error::GenericTransportError; -use jsonrpsee_core::http_helpers::read_body; -use jsonrpsee_core::server::helpers::{ - prepare_error, BatchResponse, BatchResponseBuilder, MethodResponse, -}; -use jsonrpsee_core::server::resource_limiting::Resources; -use jsonrpsee_core::server::rpc_module::{MethodKind, Methods}; -use jsonrpsee_core::tracing::{rx_log_from_json, tx_log_from_str}; -use jsonrpsee_core::JsonRawValue; -use jsonrpsee_types::error::{ErrorCode, BATCHES_NOT_SUPPORTED_CODE, BATCHES_NOT_SUPPORTED_MSG}; -use jsonrpsee_types::{ErrorObject, Id, InvalidRequest, Notification, Params, Request}; -use tokio::sync::OwnedSemaphorePermit; -use tracing::instrument; - -use crate::logger::{self, Logger, TransportProtocol}; - -type Notif<'a> = Notification<'a, Option<&'a JsonRawValue>>; - -/// Checks that content type of received request is valid for JSON-RPC. -pub(crate) fn content_type_is_json(request: &hyper::Request) -> bool { - is_json(request.headers().get(http::header::CONTENT_TYPE)) -} - -/// Returns true if the `content_type` header indicates a valid JSON message. -pub(crate) fn is_json(content_type: Option<&hyper::header::HeaderValue>) -> bool { - content_type.and_then(|val| val.to_str().ok()).map_or(false, |content| { - content.eq_ignore_ascii_case("application/json") - || content.eq_ignore_ascii_case("application/json; charset=utf-8") - || content.eq_ignore_ascii_case("application/json;charset=utf-8") - }) -} - -pub(crate) async fn reject_connection(socket: tokio::net::TcpStream) { - async fn reject( - _req: hyper::Request, - ) -> Result, Infallible> { - Ok(response::too_many_requests()) - } - - if let Err(e) = hyper::server::conn::Http::new() - .serve_connection(socket, hyper::service::service_fn(reject)) - .await - { - tracing::warn!("Error when trying to deny connection: {:?}", e); - } -} - -#[derive(Debug)] -pub(crate) struct ProcessValidatedRequest<'a, L: Logger> { - pub(crate) request: hyper::Request, - pub(crate) logger: &'a L, - pub(crate) methods: Methods, - pub(crate) resources: Resources, - pub(crate) max_request_body_size: u32, - pub(crate) max_response_body_size: u32, - pub(crate) max_log_length: u32, - pub(crate) batch_requests_supported: bool, - pub(crate) request_start: L::Instant, -} - -/// Process a verified request, it implies a POST request with content type JSON. 
-pub(crate) async fn process_validated_request( - input: ProcessValidatedRequest<'_, L>, -) -> hyper::Response { - let ProcessValidatedRequest { - request, - logger, - methods, - resources, - max_request_body_size, - max_response_body_size, - max_log_length, - batch_requests_supported, - request_start, - } = input; - - let (parts, body) = request.into_parts(); - - let (body, is_single) = match read_body(&parts.headers, body, max_request_body_size).await { - Ok(r) => r, - Err(GenericTransportError::TooLarge) => return response::too_large(max_request_body_size), - Err(GenericTransportError::Malformed) => return response::malformed(), - Err(GenericTransportError::Inner(e)) => { - tracing::error!("Internal error reading request body: {}", e); - return response::internal_error(); - } - }; - - // Single request or notification - if is_single { - let call = CallData { - conn_id: 0, - logger, - methods: &methods, - max_response_body_size, - max_log_length, - resources: &resources, - request_start, - }; - let response = process_single_request(body, call).await; - logger.on_response(&response.result, request_start, TransportProtocol::Http); - response::ok_response(response.result) - } - // Batch of requests or notifications - else if !batch_requests_supported { - let err = MethodResponse::error( - Id::Null, - ErrorObject::borrowed(BATCHES_NOT_SUPPORTED_CODE, &BATCHES_NOT_SUPPORTED_MSG, None), - ); - logger.on_response(&err.result, request_start, TransportProtocol::Http); - response::ok_response(err.result) - } - // Batch of requests or notifications - else { - let response = process_batch_request(Batch { - data: body, - call: CallData { - conn_id: 0, - logger, - methods: &methods, - max_response_body_size, - max_log_length, - resources: &resources, - request_start, - }, - }) - .await; - logger.on_response(&response.result, request_start, TransportProtocol::Http); - response::ok_response(response.result) - } -} - -#[derive(Debug, Clone)] -pub(crate) struct Batch<'a, L: Logger> { - data: Vec, - call: CallData<'a, L>, -} - -#[derive(Debug, Clone)] -pub(crate) struct CallData<'a, L: Logger> { - conn_id: usize, - logger: &'a L, - methods: &'a Methods, - max_response_body_size: u32, - max_log_length: u32, - resources: &'a Resources, - request_start: L::Instant, -} - -// Batch responses must be sent back as a single message so we read the results from each -// request in the batch and read the results off of a new channel, `rx_batch`, and then send the -// complete batch response back to the client over `tx`. -#[instrument(name = "batch", skip(b), level = "TRACE")] -pub(crate) async fn process_batch_request(b: Batch<'_, L>) -> BatchResponse -where - L: Logger, -{ - let Batch { data, call } = b; - - if let Ok(batch) = serde_json::from_slice::>(&data) { - let mut got_notif = false; - let mut batch_response = - BatchResponseBuilder::new_with_limit(call.max_response_body_size as usize); - - let mut pending_calls: FuturesOrdered<_> = batch - .into_iter() - .filter_map(|v| { - if let Ok(req) = serde_json::from_str::(v.get()) { - Some(Either::Right(execute_call(req, call.clone()))) - } else if let Ok(_notif) = - serde_json::from_str::>(v.get()) - { - // notifications should not be answered. 
- got_notif = true; - None - } else { - // valid JSON but could be not parsable as `InvalidRequest` - let id = match serde_json::from_str::(v.get()) { - Ok(err) => err.id, - Err(_) => Id::Null, - }; - - Some(Either::Left(async { - MethodResponse::error(id, ErrorObject::from(ErrorCode::InvalidRequest)) - })) - } - }) - .collect(); - - while let Some(response) = pending_calls.next().await { - if let Err(too_large) = batch_response.append(&response) { - return too_large; - } - } - - if got_notif && batch_response.is_empty() { - BatchResponse { result: String::new(), success: true } - } else { - batch_response.finish() - } - } else { - BatchResponse::error(Id::Null, ErrorObject::from(ErrorCode::ParseError)) - } -} - -pub(crate) async fn process_single_request( - data: Vec, - call: CallData<'_, L>, -) -> MethodResponse { - if let Ok(req) = serde_json::from_slice::(&data) { - execute_call_with_tracing(req, call).await - } else if let Ok(notif) = serde_json::from_slice::(&data) { - execute_notification(notif, call.max_log_length) - } else { - let (id, code) = prepare_error(&data); - MethodResponse::error(id, ErrorObject::from(code)) - } -} - -#[instrument(name = "method_call", fields(method = req.method.as_ref()), skip(call, req), level = "TRACE")] -pub(crate) async fn execute_call_with_tracing<'a, L: Logger>( - req: Request<'a>, - call: CallData<'_, L>, -) -> MethodResponse { - execute_call(req, call).await -} - -pub(crate) async fn execute_call( - req: Request<'_>, - call: CallData<'_, L>, -) -> MethodResponse { - let CallData { - resources, - methods, - logger, - max_response_body_size, - max_log_length, - conn_id, - request_start, - } = call; - - rx_log_from_json(&req, call.max_log_length); - - let params = Params::new(req.params.map(|params| params.get())); - let name = &req.method; - let id = req.id; - - let response = match methods.method_with_name(name) { - None => { - logger.on_call( - name, - params.clone(), - logger::MethodKind::Unknown, - TransportProtocol::Http, - ); - MethodResponse::error(id, ErrorObject::from(ErrorCode::MethodNotFound)) - } - Some((name, method)) => match &method.inner() { - MethodKind::Sync(callback) => { - logger.on_call( - name, - params.clone(), - logger::MethodKind::MethodCall, - TransportProtocol::Http, - ); - - match method.claim(name, resources) { - Ok(guard) => { - let r = (callback)(id, params, max_response_body_size as usize); - drop(guard); - r - } - Err(err) => { - tracing::error!( - "[Methods::execute_with_resources] failed to lock resources: {}", - err - ); - MethodResponse::error(id, ErrorObject::from(ErrorCode::ServerIsBusy)) - } - } - } - MethodKind::Async(callback) => { - logger.on_call( - name, - params.clone(), - logger::MethodKind::MethodCall, - TransportProtocol::Http, - ); - match method.claim(name, resources) { - Ok(guard) => { - let id = id.into_owned(); - let params = params.into_owned(); - - (callback)( - id, - params, - conn_id, - max_response_body_size as usize, - Some(guard), - ) - .await - } - Err(err) => { - tracing::error!( - "[Methods::execute_with_resources] failed to lock resources: {}", - err - ); - MethodResponse::error(id, ErrorObject::from(ErrorCode::ServerIsBusy)) - } - } - } - MethodKind::Subscription(_) | MethodKind::Unsubscription(_) => { - logger.on_call( - name, - params.clone(), - logger::MethodKind::Unknown, - TransportProtocol::Http, - ); - tracing::error!("Subscriptions not supported on HTTP"); - MethodResponse::error(id, ErrorObject::from(ErrorCode::InternalError)) - } - }, - }; - - 
tx_log_from_str(&response.result, max_log_length); - logger.on_result(name, response.success, request_start, TransportProtocol::Http); - response -} - -#[instrument(name = "notification", fields(method = notif.method.as_ref()), skip(notif, max_log_length), level = "TRACE")] -fn execute_notification(notif: Notif, max_log_length: u32) -> MethodResponse { - rx_log_from_json(¬if, max_log_length); - let response = MethodResponse { result: String::new(), success: true }; - tx_log_from_str(&response.result, max_log_length); - response -} - -pub(crate) struct HandleRequest { - pub(crate) methods: Methods, - pub(crate) resources: Resources, - pub(crate) max_request_body_size: u32, - pub(crate) max_response_body_size: u32, - pub(crate) max_log_length: u32, - pub(crate) batch_requests_supported: bool, - pub(crate) logger: L, - pub(crate) conn: Arc, - pub(crate) remote_addr: SocketAddr, -} - -pub(crate) async fn handle_request( - request: hyper::Request, - input: HandleRequest, -) -> hyper::Response { - let HandleRequest { - methods, - resources, - max_request_body_size, - max_response_body_size, - max_log_length, - batch_requests_supported, - logger, - conn, - remote_addr, - } = input; - - let request_start = logger.on_request(TransportProtocol::Http); - - // Only the `POST` method is allowed. - let res = match *request.method() { - Method::POST if content_type_is_json(&request) => { - process_validated_request(ProcessValidatedRequest { - request, - methods, - resources, - max_request_body_size, - max_response_body_size, - max_log_length, - batch_requests_supported, - logger: &logger, - request_start, - }) - .await - } - // Error scenarios: - Method::POST => response::unsupported_content_type(), - _ => response::method_not_allowed(), - }; - - drop(conn); - logger.on_disconnect(remote_addr, TransportProtocol::Http); - - res -} - pub(crate) mod response { - use jsonrpsee_types::error::{reject_too_big_request, ErrorCode, ErrorResponse}; + use jsonrpsee_types::error::{ErrorCode, ErrorResponse}; use jsonrpsee_types::Id; const JSON: &str = "application/json; charset=utf-8"; - const TEXT: &str = "text/plain"; /// Create a response for json internal error. pub(crate) fn internal_error() -> hyper::Response { @@ -417,44 +15,6 @@ pub(crate) mod response { from_template(hyper::StatusCode::INTERNAL_SERVER_ERROR, error, JSON) } - /// Create a text/plain response for not allowed hosts. - pub(crate) fn host_not_allowed() -> hyper::Response { - from_template( - hyper::StatusCode::FORBIDDEN, - "Provided Host header is not whitelisted.\n".to_owned(), - TEXT, - ) - } - - /// Create a text/plain response for disallowed method used. - pub(crate) fn method_not_allowed() -> hyper::Response { - from_template( - hyper::StatusCode::METHOD_NOT_ALLOWED, - "Used HTTP Method is not allowed. 
POST or OPTIONS is required\n".to_owned(), - TEXT, - ) - } - - /// Create a json response for oversized requests (413) - pub(crate) fn too_large(limit: u32) -> hyper::Response { - let error = serde_json::to_string(&ErrorResponse::borrowed( - reject_too_big_request(limit), - Id::Null, - )) - .expect("built from known-good data; qed"); - - from_template(hyper::StatusCode::PAYLOAD_TOO_LARGE, error, JSON) - } - - /// Create a json response for empty or malformed requests (400) - pub(crate) fn malformed() -> hyper::Response { - let error = - serde_json::to_string(&ErrorResponse::borrowed(ErrorCode::ParseError.into(), Id::Null)) - .expect("built from known-good data; qed"); - - from_template(hyper::StatusCode::BAD_REQUEST, error, JSON) - } - /// Create a response body. fn from_template>( status: hyper::StatusCode, @@ -474,28 +34,4 @@ pub(crate) mod response { pub(crate) fn ok_response(body: String) -> hyper::Response { from_template(hyper::StatusCode::OK, body, JSON) } - - /// Create a response for unsupported content type. - pub(crate) fn unsupported_content_type() -> hyper::Response { - from_template( - hyper::StatusCode::UNSUPPORTED_MEDIA_TYPE, - "Supplied content type is not allowed. Content-Type: application/json is required\n" - .to_owned(), - TEXT, - ) - } - - /// Create a response for when the server is busy and can't accept more requests. - pub(crate) fn too_many_requests() -> hyper::Response { - from_template( - hyper::StatusCode::TOO_MANY_REQUESTS, - "Too many connections. Please try again later.".to_owned(), - TEXT, - ) - } - - /// Create a response for when the server denied the request. - pub(crate) fn denied() -> hyper::Response { - from_template(hyper::StatusCode::FORBIDDEN, "".to_owned(), TEXT) - } } From 2cb90f1f576708de3e5499843f22ca97a4da6731 Mon Sep 17 00:00:00 2001 From: Fabricio Robles Date: Fri, 3 Jan 2025 17:59:07 -0600 Subject: [PATCH 17/21] clippy.sh --- crates/katana/rpc/rpc/src/dev.rs | 5 ++--- crates/katana/rpc/rpc/src/lib.rs | 7 ------- crates/katana/rpc/rpc/src/proxy_get_request.rs | 2 +- 3 files changed, 3 insertions(+), 11 deletions(-) diff --git a/crates/katana/rpc/rpc/src/dev.rs b/crates/katana/rpc/rpc/src/dev.rs index c3c16c246d..499ea223d0 100644 --- a/crates/katana/rpc/rpc/src/dev.rs +++ b/crates/katana/rpc/rpc/src/dev.rs @@ -101,9 +101,8 @@ impl DevApiServer for DevApi { async fn account_balance(&self, address: String, unit: String) -> Result { let account_address: ContractAddress = Felt::from_str(address.as_str()).unwrap().into(); - let unit = Some(PriceUnit::from_str(unit.to_uppercase().as_str())) - .unwrap() - .unwrap_or(PriceUnit::Wei); + let unit = PriceUnit::from_str(unit.to_uppercase().as_str()).unwrap_or(PriceUnit::Wei); + let erc20_address = get_erc20_address(&unit).map_err(|_| http::response::internal_error()).unwrap(); diff --git a/crates/katana/rpc/rpc/src/lib.rs b/crates/katana/rpc/rpc/src/lib.rs index 5b34daf128..56de7c6202 100644 --- a/crates/katana/rpc/rpc/src/lib.rs +++ b/crates/katana/rpc/rpc/src/lib.rs @@ -1,8 +1,3 @@ -//! RPC implementations. 
- -#![allow(clippy::blocks_in_conditions)] -#![cfg_attr(not(test), warn(unused_crate_dependencies))] - use std::net::SocketAddr; use std::time::Duration; @@ -21,8 +16,6 @@ pub mod saya; pub mod starknet; pub mod torii; -mod future; -mod logger; mod transport; mod utils; diff --git a/crates/katana/rpc/rpc/src/proxy_get_request.rs b/crates/katana/rpc/rpc/src/proxy_get_request.rs index 26d7ca263a..90d8a5d7a1 100644 --- a/crates/katana/rpc/rpc/src/proxy_get_request.rs +++ b/crates/katana/rpc/rpc/src/proxy_get_request.rs @@ -123,7 +123,7 @@ where result: &'a serde_json::value::RawValue, } - let response = if let Ok(payload) = serde_json::from_slice::(&bytes) { + let response = if let Ok(payload) = serde_json::from_slice::>(&bytes) { http::response::ok_response(payload.result.to_string()) } else { http::response::internal_error() From 0c14c0366e0a13fe361bdaa69f216d5e7e990fc2 Mon Sep 17 00:00:00 2001 From: Fabricio Robles Date: Fri, 3 Jan 2025 18:00:26 -0600 Subject: [PATCH 18/21] rust_fmt.sh --- crates/katana/rpc/rpc-api/src/dev.rs | 2 +- crates/katana/rpc/rpc/src/dev.rs | 3 ++- crates/katana/rpc/rpc/src/proxy_get_request.rs | 11 ++++++----- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/crates/katana/rpc/rpc-api/src/dev.rs b/crates/katana/rpc/rpc-api/src/dev.rs index 49ce1f8d39..94441fb567 100644 --- a/crates/katana/rpc/rpc-api/src/dev.rs +++ b/crates/katana/rpc/rpc-api/src/dev.rs @@ -20,7 +20,7 @@ pub trait DevApi { #[method(name = "setStorageAt")] async fn set_storage_at(&self, contract_address: Felt, key: Felt, value: Felt) - -> RpcResult<()>; + -> RpcResult<()>; #[method(name = "predeployedAccounts")] async fn predeployed_accounts(&self) -> RpcResult>; diff --git a/crates/katana/rpc/rpc/src/dev.rs b/crates/katana/rpc/rpc/src/dev.rs index 499ea223d0..1f6b05b7fe 100644 --- a/crates/katana/rpc/rpc/src/dev.rs +++ b/crates/katana/rpc/rpc/src/dev.rs @@ -1,7 +1,6 @@ use std::str::FromStr; use std::sync::Arc; -use crate::transport::http; use jsonrpsee::core::{async_trait, Error}; use katana_core::backend::Backend; use katana_core::service::block_producer::{BlockProducer, BlockProducerMode, PendingExecutor}; @@ -17,6 +16,8 @@ use katana_rpc_types::account::Account; use katana_rpc_types::error::dev::DevApiError; use starknet_crypto::Felt; +use crate::transport::http; + #[allow(missing_debug_implementations)] pub struct DevApi { backend: Arc>, diff --git a/crates/katana/rpc/rpc/src/proxy_get_request.rs b/crates/katana/rpc/rpc/src/proxy_get_request.rs index 90d8a5d7a1..97b2c38857 100644 --- a/crates/katana/rpc/rpc/src/proxy_get_request.rs +++ b/crates/katana/rpc/rpc/src/proxy_get_request.rs @@ -1,16 +1,17 @@ //! Middleware that proxies requests at a specified URI to internal //! RPC method calls. 
+use std::collections::HashMap; +use std::error::Error; +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; + use hyper::header::{ACCEPT, CONTENT_TYPE}; use hyper::http::HeaderValue; use hyper::{Body, Method, Request, Response, Uri}; use jsonrpsee_core::error::Error as RpcError; use jsonrpsee_core::JsonRawValue; use jsonrpsee_types::{Id, RequestSer}; -use std::collections::HashMap; -use std::error::Error; -use std::future::Future; -use std::pin::Pin; -use std::task::{Context, Poll}; use tower::{Layer, Service}; use url::form_urlencoded; From 3d73532acb27952cb50712353f921459bae38533 Mon Sep 17 00:00:00 2001 From: Fabricio Robles Date: Fri, 3 Jan 2025 19:43:13 -0600 Subject: [PATCH 19/21] Fix default case in match and early return when invalid route provided --- crates/katana/rpc/rpc/src/proxy_get_request.rs | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/crates/katana/rpc/rpc/src/proxy_get_request.rs b/crates/katana/rpc/rpc/src/proxy_get_request.rs index 97b2c38857..32c57db1ad 100644 --- a/crates/katana/rpc/rpc/src/proxy_get_request.rs +++ b/crates/katana/rpc/rpc/src/proxy_get_request.rs @@ -6,6 +6,7 @@ use std::future::Future; use std::pin::Pin; use std::task::{Context, Poll}; +use futures::future; use hyper::header::{ACCEPT, CONTENT_TYPE}; use hyper::http::HeaderValue; use hyper::{Body, Method, Request, Response, Uri}; @@ -90,9 +91,13 @@ where let (params, method) = match path { "/account_balance" => get_account_balance(query), - _ => (JsonRawValue::from_string(String::new()).unwrap(), "".to_string()), + _ => (JsonRawValue::from_string("{}".to_string()).unwrap(), "".to_string()), }; + if method.is_empty() { + return Box::pin(future::ok(http::response::ok_response("Unknown route".to_string()))); + } + // RPC methods are accessed with `POST`. *req.method_mut() = Method::POST; // Precautionary remove the URI. From 17a53b54f6650bae455988a117fc40d32b9f96d9 Mon Sep 17 00:00:00 2001 From: Fabricio Robles Date: Fri, 3 Jan 2025 23:09:49 -0600 Subject: [PATCH 20/21] Fix error when returning response when route is not known in DevnetProxyLayer --- .../katana/rpc/rpc/src/proxy_get_request.rs | 44 +++++++++++-------- 1 file changed, 25 insertions(+), 19 deletions(-) diff --git a/crates/katana/rpc/rpc/src/proxy_get_request.rs b/crates/katana/rpc/rpc/src/proxy_get_request.rs index 32c57db1ad..b771e59d54 100644 --- a/crates/katana/rpc/rpc/src/proxy_get_request.rs +++ b/crates/katana/rpc/rpc/src/proxy_get_request.rs @@ -6,7 +6,6 @@ use std::future::Future; use std::pin::Pin; use std::task::{Context, Poll}; -use futures::future; use hyper::header::{ACCEPT, CONTENT_TYPE}; use hyper::http::HeaderValue; use hyper::{Body, Method, Request, Response, Uri}; @@ -94,25 +93,27 @@ where _ => (JsonRawValue::from_string("{}".to_string()).unwrap(), "".to_string()), }; - if method.is_empty() { - return Box::pin(future::ok(http::response::ok_response("Unknown route".to_string()))); - } - - // RPC methods are accessed with `POST`. - *req.method_mut() = Method::POST; - // Precautionary remove the URI. - *req.uri_mut() = Uri::from_static("/"); - - // Requests must have the following headers: - req.headers_mut().insert(CONTENT_TYPE, HeaderValue::from_static("application/json")); - req.headers_mut().insert(ACCEPT, HeaderValue::from_static("application/json")); - - // Adjust the body to reflect the method call. 
- let body = Body::from( - serde_json::to_string(&RequestSer::borrowed(&Id::Number(0), &method, Some(¶ms))) + if !method.is_empty() { + // RPC methods are accessed with `POST`. + *req.method_mut() = Method::POST; + // Precautionary remove the URI. + *req.uri_mut() = Uri::from_static("/"); + + // Requests must have the following headers: + req.headers_mut().insert(CONTENT_TYPE, HeaderValue::from_static("application/json")); + req.headers_mut().insert(ACCEPT, HeaderValue::from_static("application/json")); + + // Adjust the body to reflect the method call. + let body = Body::from( + serde_json::to_string(&RequestSer::borrowed( + &Id::Number(0), + &method, + Some(¶ms), + )) .expect("Valid request; qed"), - ); - req = req.map(|_| body); + ); + req = req.map(|_| body); + } // Call the inner service and get a future that resolves to the response. let fut = self.inner.call(req); @@ -120,6 +121,11 @@ where // Adjust the response if needed. let res_fut = async move { let res = fut.await.map_err(|err| err.into())?; + + if method.is_empty() { + return Ok(res); + } + let body = res.into_body(); let bytes = hyper::body::to_bytes(body).await?; From 5c6981bf24b5528068b04b6181d6ae105c82bb27 Mon Sep 17 00:00:00 2001 From: Fabricio Robles Date: Wed, 8 Jan 2025 20:22:39 -0600 Subject: [PATCH 21/21] Apply code rabbit suggestions --- crates/katana/rpc/rpc/src/proxy_get_request.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/crates/katana/rpc/rpc/src/proxy_get_request.rs b/crates/katana/rpc/rpc/src/proxy_get_request.rs index b771e59d54..e91b282d83 100644 --- a/crates/katana/rpc/rpc/src/proxy_get_request.rs +++ b/crates/katana/rpc/rpc/src/proxy_get_request.rs @@ -158,5 +158,9 @@ fn get_account_balance(query: Option<&str>) -> (Box, std::string:: let unit = params.get("unit").unwrap_or(&default); let json_string = format!(r#"{{"address":"{}", "unit":"{}"}}"#, address, unit); - (JsonRawValue::from_string(json_string).unwrap(), "dev_accountBalance".to_string()) + let raw_value = match JsonRawValue::from_string(json_string) { + Ok(val) => val, + Err(_) => JsonRawValue::from_string(r#"{}"#.to_string()).unwrap(), + }; + (raw_value, "dev_accountBalance".to_string()) }