diff --git a/Cargo.lock b/Cargo.lock index c34d601928..35a24d284c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8674,6 +8674,7 @@ dependencies = [ "starknet 0.12.0", "strum 0.25.0", "strum_macros 0.25.3", + "thiserror 1.0.63", "tokio", "tower 0.4.13", "tower-http 0.4.4", diff --git a/crates/dojo/test-utils/src/sequencer.rs b/crates/dojo/test-utils/src/sequencer.rs index 6b4bdb8ee1..ef1b023eec 100644 --- a/crates/dojo/test-utils/src/sequencer.rs +++ b/crates/dojo/test-utils/src/sequencer.rs @@ -1,16 +1,16 @@ -use std::collections::HashSet; use std::sync::Arc; use katana_core::backend::Backend; use katana_core::constants::DEFAULT_SEQUENCER_ADDRESS; use katana_executor::implementation::blockifier::BlockifierFactory; use katana_node::config::dev::DevConfig; -use katana_node::config::rpc::{ApiKind, RpcConfig, DEFAULT_RPC_ADDR, DEFAULT_RPC_MAX_CONNECTIONS}; +use katana_node::config::rpc::{RpcConfig, DEFAULT_RPC_ADDR, DEFAULT_RPC_MAX_CONNECTIONS}; pub use katana_node::config::*; use katana_node::LaunchedNode; use katana_primitives::chain::ChainId; use katana_primitives::chain_spec::ChainSpec; use katana_rpc::Error; +use rpc::RpcModulesList; use starknet::accounts::{ExecutionEncoding, SingleOwnerAccount}; use starknet::core::chain_id; use starknet::core::types::{BlockId, BlockTag, Felt}; @@ -122,8 +122,8 @@ pub fn get_default_test_config(sequencing: SequencingConfig) -> Config { cors_origins: Vec::new(), port: 0, addr: DEFAULT_RPC_ADDR, + apis: RpcModulesList::all(), max_connections: DEFAULT_RPC_MAX_CONNECTIONS, - apis: HashSet::from([ApiKind::Starknet, ApiKind::Dev, ApiKind::Saya, ApiKind::Torii]), max_event_page_size: Some(100), max_proof_keys: Some(100), }; diff --git a/crates/katana/cli/src/args.rs b/crates/katana/cli/src/args.rs index 2f9f44f9b2..cd57acb1b2 100644 --- a/crates/katana/cli/src/args.rs +++ b/crates/katana/cli/src/args.rs @@ -1,11 +1,10 @@ //! Katana node CLI options and configuration. -use std::collections::HashSet; use std::path::PathBuf; use std::sync::Arc; use alloy_primitives::U256; -use anyhow::{Context, Result}; +use anyhow::{bail, Context, Result}; use clap::Parser; use katana_core::constants::DEFAULT_SEQUENCER_ADDRESS; use katana_core::service::messaging::MessagingConfig; @@ -14,7 +13,7 @@ use katana_node::config::dev::{DevConfig, FixedL1GasPriceConfig}; use katana_node::config::execution::ExecutionConfig; use katana_node::config::fork::ForkingConfig; use katana_node::config::metrics::MetricsConfig; -use katana_node::config::rpc::{ApiKind, RpcConfig}; +use katana_node::config::rpc::{RpcConfig, RpcModuleKind, RpcModulesList}; use katana_node::config::{Config, SequencingConfig}; use katana_primitives::chain_spec::{self, ChainSpec}; use katana_primitives::genesis::allocation::DevAllocationsGenerator; @@ -169,7 +168,7 @@ impl NodeArgs { pub fn config(&self) -> Result { let db = self.db_config(); - let rpc = self.rpc_config(); + let rpc = self.rpc_config()?; let dev = self.dev_config(); let chain = self.chain_spec()?; let metrics = self.metrics_config(); @@ -197,29 +196,39 @@ impl NodeArgs { SequencingConfig { block_time: self.block_time, no_mining: self.no_mining } } - fn rpc_config(&self) -> RpcConfig { - let mut apis = HashSet::from([ApiKind::Starknet, ApiKind::Torii, ApiKind::Saya]); - // only enable `katana` API in dev mode - if self.development.dev { - apis.insert(ApiKind::Dev); - } + fn rpc_config(&self) -> Result { + let modules = if let Some(modules) = &self.server.http_modules { + // TODO: This check should be handled in the `katana-node` level. Right now if you + // instantiate katana programmatically, you can still add the dev module without + // enabling dev mode. + // + // We only allow the `dev` module in dev mode (ie `--dev` flag) + if !self.development.dev && modules.contains(&RpcModuleKind::Dev) { + bail!("The `dev` module can only be enabled in dev mode (ie `--dev` flag)") + } + + modules.clone() + } else { + // Expose the default modules if none is specified. + RpcModulesList::default() + }; #[cfg(feature = "server")] { - RpcConfig { - apis, + Ok(RpcConfig { + apis: modules, port: self.server.http_port, addr: self.server.http_addr, max_connections: self.server.max_connections, cors_origins: self.server.http_cors_origins.clone(), max_event_page_size: Some(self.server.max_event_page_size), max_proof_keys: Some(self.server.max_proof_keys), - } + }) } #[cfg(not(feature = "server"))] { - RpcConfig { apis, ..Default::default() } + Ok(RpcConfig { apis, ..Default::default() }) } } @@ -636,4 +645,28 @@ chain_id.Named = "Mainnet" assert!(cors_origins.contains(&HeaderValue::from_static("http://localhost:3000"))); assert!(cors_origins.contains(&HeaderValue::from_static("https://example.com"))); } + + #[test] + fn http_modules() { + // If the `--http.api` isn't specified, only starknet module will be exposed. + let config = NodeArgs::parse_from(["katana"]).config().unwrap(); + let modules = config.rpc.apis; + assert_eq!(modules.len(), 1); + assert!(modules.contains(&RpcModuleKind::Starknet)); + + // If the `--http.api` is specified, only the ones in the list will be exposed. + let config = NodeArgs::parse_from(["katana", "--http.api", "saya,torii"]).config().unwrap(); + let modules = config.rpc.apis; + assert_eq!(modules.len(), 2); + assert!(modules.contains(&RpcModuleKind::Saya)); + assert!(modules.contains(&RpcModuleKind::Torii)); + + // Specifiying the dev module without enabling dev mode is forbidden. + let err = + NodeArgs::parse_from(["katana", "--http.api", "starknet,dev"]).config().unwrap_err(); + assert!( + err.to_string() + .contains("The `dev` module can only be enabled in dev mode (ie `--dev` flag)") + ); + } } diff --git a/crates/katana/cli/src/options.rs b/crates/katana/cli/src/options.rs index 2d443f51fd..88adabbd9b 100644 --- a/crates/katana/cli/src/options.rs +++ b/crates/katana/cli/src/options.rs @@ -12,7 +12,7 @@ use std::net::IpAddr; use clap::Args; use katana_node::config::execution::{DEFAULT_INVOCATION_MAX_STEPS, DEFAULT_VALIDATION_MAX_STEPS}; use katana_node::config::metrics::{DEFAULT_METRICS_ADDR, DEFAULT_METRICS_PORT}; -use katana_node::config::rpc::DEFAULT_RPC_MAX_PROOF_KEYS; +use katana_node::config::rpc::{RpcModulesList, DEFAULT_RPC_MAX_PROOF_KEYS}; #[cfg(feature = "server")] use katana_node::config::rpc::{ DEFAULT_RPC_ADDR, DEFAULT_RPC_MAX_CONNECTIONS, DEFAULT_RPC_MAX_EVENT_PAGE_SIZE, @@ -97,6 +97,12 @@ pub struct ServerOptions { )] pub http_cors_origins: Vec, + /// API's offered over the HTTP-RPC interface. + #[arg(long = "http.api", value_name = "MODULES")] + #[arg(value_parser = RpcModulesList::parse)] + #[serde(default)] + pub http_modules: Option, + /// Maximum number of concurrent connections allowed. #[arg(long = "rpc.max-connections", value_name = "COUNT")] #[arg(default_value_t = DEFAULT_RPC_MAX_CONNECTIONS)] @@ -122,8 +128,9 @@ impl Default for ServerOptions { ServerOptions { http_addr: DEFAULT_RPC_ADDR, http_port: DEFAULT_RPC_PORT, - max_connections: DEFAULT_RPC_MAX_CONNECTIONS, http_cors_origins: Vec::new(), + http_modules: Some(RpcModulesList::default()), + max_connections: DEFAULT_RPC_MAX_CONNECTIONS, max_event_page_size: DEFAULT_RPC_MAX_EVENT_PAGE_SIZE, max_proof_keys: DEFAULT_RPC_MAX_PROOF_KEYS, } diff --git a/crates/katana/node/Cargo.toml b/crates/katana/node/Cargo.toml index 27c4905a8e..aa47c61777 100644 --- a/crates/katana/node/Cargo.toml +++ b/crates/katana/node/Cargo.toml @@ -25,6 +25,7 @@ jsonrpsee.workspace = true serde.workspace = true serde_json.workspace = true starknet.workspace = true +thiserror.workspace = true tower = { workspace = true, features = [ "full" ] } tower-http = { workspace = true, features = [ "full" ] } tracing.workspace = true diff --git a/crates/katana/node/src/config/rpc.rs b/crates/katana/node/src/config/rpc.rs index 5cc81b538d..9093c8ad0f 100644 --- a/crates/katana/node/src/config/rpc.rs +++ b/crates/katana/node/src/config/rpc.rs @@ -2,6 +2,7 @@ use std::collections::HashSet; use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use katana_rpc::cors::HeaderValue; +use serde::{Deserialize, Serialize}; /// The default maximum number of concurrent RPC connections. pub const DEFAULT_RPC_MAX_CONNECTIONS: u32 = 100; @@ -13,15 +14,25 @@ pub const DEFAULT_RPC_MAX_EVENT_PAGE_SIZE: u64 = 1024; /// Default maximmum number of keys for the `starknet_getStorageProof` RPC method. pub const DEFAULT_RPC_MAX_PROOF_KEYS: u64 = 100; -/// List of APIs supported by Katana. +/// List of RPC modules supported by Katana. #[derive( - Debug, Copy, Clone, PartialEq, Eq, Hash, strum_macros::EnumString, strum_macros::Display, + Debug, + Copy, + Clone, + PartialEq, + Eq, + Hash, + strum_macros::EnumString, + strum_macros::Display, + Serialize, + Deserialize, )] -pub enum ApiKind { +#[strum(ascii_case_insensitive)] +pub enum RpcModuleKind { Starknet, Torii, - Dev, Saya, + Dev, } /// Configuration for the RPC server. @@ -30,7 +41,7 @@ pub struct RpcConfig { pub addr: IpAddr, pub port: u16, pub max_connections: u32, - pub apis: HashSet, + pub apis: RpcModulesList, pub cors_origins: Vec, pub max_event_page_size: Option, pub max_proof_keys: Option, @@ -49,10 +60,126 @@ impl Default for RpcConfig { cors_origins: Vec::new(), addr: DEFAULT_RPC_ADDR, port: DEFAULT_RPC_PORT, + apis: RpcModulesList::default(), max_connections: DEFAULT_RPC_MAX_CONNECTIONS, - apis: HashSet::from([ApiKind::Starknet]), max_event_page_size: Some(DEFAULT_RPC_MAX_EVENT_PAGE_SIZE), max_proof_keys: Some(DEFAULT_RPC_MAX_PROOF_KEYS), } } } + +#[derive(Debug, thiserror::Error)] +#[error("invalid module: {0}")] +pub struct InvalidRpcModuleError(String); + +#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)] +#[serde(transparent)] +pub struct RpcModulesList(HashSet); + +impl RpcModulesList { + /// Creates an empty modules list. + pub fn new() -> Self { + Self(HashSet::new()) + } + + /// Creates a list with all the possible modules. + pub fn all() -> Self { + Self(HashSet::from([ + RpcModuleKind::Starknet, + RpcModuleKind::Torii, + RpcModuleKind::Saya, + RpcModuleKind::Dev, + ])) + } + + /// Adds a `module` to the list. + pub fn add(&mut self, module: RpcModuleKind) { + self.0.insert(module); + } + + /// Returns `true` if the list contains the specified `module`. + pub fn contains(&self, module: &RpcModuleKind) -> bool { + self.0.contains(module) + } + + /// Returns the number of modules in the list. + pub fn len(&self) -> usize { + self.0.len() + } + + /// Returns `true` if the list contains no modules. + pub fn is_empty(&self) -> bool { + self.0.is_empty() + } + + /// Used as the value parser for `clap`. + pub fn parse(value: &str) -> Result { + if value.is_empty() { + return Ok(Self::new()); + } + + let mut modules = HashSet::new(); + for module_str in value.split(',') { + let module: RpcModuleKind = module_str + .trim() + .parse() + .map_err(|_| InvalidRpcModuleError(module_str.to_string()))?; + + modules.insert(module); + } + + Ok(Self(modules)) + } +} + +impl Default for RpcModulesList { + fn default() -> Self { + Self(HashSet::from([RpcModuleKind::Starknet])) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse_empty() { + let list = RpcModulesList::parse("").unwrap(); + assert_eq!(list, RpcModulesList::new()); + } + + #[test] + fn test_parse_single() { + let list = RpcModulesList::parse("dev").unwrap(); + assert!(list.contains(&RpcModuleKind::Dev)); + } + + #[test] + fn test_parse_multiple() { + let list = RpcModulesList::parse("dev,torii,saya").unwrap(); + assert!(list.contains(&RpcModuleKind::Dev)); + assert!(list.contains(&RpcModuleKind::Torii)); + assert!(list.contains(&RpcModuleKind::Saya)); + } + + #[test] + fn test_parse_with_spaces() { + let list = RpcModulesList::parse(" dev , torii ").unwrap(); + assert!(list.contains(&RpcModuleKind::Dev)); + assert!(list.contains(&RpcModuleKind::Torii)); + } + + #[test] + fn test_parse_duplicates() { + let list = RpcModulesList::parse("dev,dev,torii").unwrap(); + let mut expected = RpcModulesList::new(); + expected.add(RpcModuleKind::Dev); + expected.add(RpcModuleKind::Torii); + assert_eq!(list, expected); + } + + #[test] + fn test_parse_invalid() { + assert!(RpcModulesList::parse("invalid").is_err()); + } +} diff --git a/crates/katana/node/src/lib.rs b/crates/katana/node/src/lib.rs index d220e45d7d..bebddacae0 100644 --- a/crates/katana/node/src/lib.rs +++ b/crates/katana/node/src/lib.rs @@ -11,7 +11,7 @@ use std::future::IntoFuture; use std::sync::Arc; use anyhow::Result; -use config::rpc::ApiKind; +use config::rpc::RpcModuleKind; use config::Config; use dojo_metrics::exporters::prometheus::PrometheusRecorder; use dojo_metrics::{Report, Server as MetricsServer}; @@ -252,7 +252,7 @@ pub async fn build(mut config: Config) -> Result { .allow_methods([Method::POST, Method::GET]) .allow_headers([hyper::header::CONTENT_TYPE, "argent-client".parse().unwrap(), "argent-version".parse().unwrap()]); - if config.rpc.apis.contains(&ApiKind::Starknet) { + if config.rpc.apis.contains(&RpcModuleKind::Starknet) { let cfg = StarknetApiConfig { max_event_page_size: config.rpc.max_event_page_size, max_proof_keys: config.rpc.max_proof_keys, @@ -275,17 +275,17 @@ pub async fn build(mut config: Config) -> Result { rpc_modules.merge(StarknetTraceApiServer::into_rpc(api))?; } - if config.rpc.apis.contains(&ApiKind::Dev) { + if config.rpc.apis.contains(&RpcModuleKind::Dev) { let api = DevApi::new(backend.clone(), block_producer.clone()); rpc_modules.merge(api.into_rpc())?; } - if config.rpc.apis.contains(&ApiKind::Torii) { + if config.rpc.apis.contains(&RpcModuleKind::Torii) { let api = ToriiApi::new(backend.clone(), pool.clone(), block_producer.clone()); rpc_modules.merge(api.into_rpc())?; } - if config.rpc.apis.contains(&ApiKind::Saya) { + if config.rpc.apis.contains(&RpcModuleKind::Saya) { let api = SayaApi::new(backend.clone(), block_producer.clone()); rpc_modules.merge(api.into_rpc())?; }