From 967ccb707c19f4420ca0aef2a5f6a65cd5e8a7f2 Mon Sep 17 00:00:00 2001 From: Annika Hannig Date: Wed, 21 Aug 2024 15:17:42 +0200 Subject: [PATCH] use new birdc interface --- src/api/neighbors.rs | 86 ++---------- src/api/responses.rs | 24 +++- src/api/server.rs | 10 +- src/api/status.rs | 18 +-- src/api/tables.rs | 52 ++----- src/bird.rs | 322 ++++++++++++++++++++++++++++++++++++------- src/lib.rs | 4 + src/main.rs | 28 +++- 8 files changed, 349 insertions(+), 195 deletions(-) diff --git a/src/api/neighbors.rs b/src/api/neighbors.rs index 91bddd3..8354d92 100644 --- a/src/api/neighbors.rs +++ b/src/api/neighbors.rs @@ -1,36 +1,21 @@ -use std::collections::HashMap; -use std::io::BufReader; - use anyhow::Result; use axum::extract::Path; -use tokio::task; use crate::{ api::{ responses::{NeighborsResponse, RoutesResponse}, Error, }, - bird, - parsers::{ - neighbors::NeighborReader, parser::BlockIterator, - routes::RE_ROUTES_START, routes_worker::RoutesWorkerPool, - }, - state::{Neighbor, Route}, + bird::{Birdc, ProtocolID}, }; /// List all neighbors (show protocols all, filter BGP) pub async fn list() -> Result { - let result = bird::birdc(bird::Command::ShowProtocolsAll)?; - let buf = BufReader::new(result); - let reader = NeighborReader::new(buf); - let neighbors: Vec = - reader.filter(|n| !n.id.is_empty()).collect(); - - let neighbors: HashMap = - neighbors.into_iter().map(|n| (n.id.clone(), n)).collect(); + let birdc = Birdc::default(); + let protocols = birdc.show_protocols_all().await?; let response = NeighborsResponse { - protocols: neighbors, + protocols, ..Default::default() }; let body = serde_json::to_string(&response)?; @@ -41,24 +26,9 @@ pub async fn list() -> Result { pub async fn list_routes_received( Path(id): Path, ) -> Result { - let result = bird::birdc(bird::Command::ShowRouteAllProtocol(id))?; - let buf = BufReader::new(result); - let blocks = BlockIterator::new(buf, &RE_ROUTES_START); - let mut routes: Vec = vec![]; - - // Spawn workers - let (blocks_tx, mut results_rx) = RoutesWorkerPool::spawn(); - - task::spawn_blocking(move || { - for block in blocks { - blocks_tx.send(block).unwrap(); - } - }); - - while let Some(result) = results_rx.recv().await { - let result = result?; - routes.extend(result); - } + let birdc = Birdc::default(); + let protocol = ProtocolID::parse(&id)?; + let routes = birdc.show_route_all_protocol(&protocol).await?; let response = RoutesResponse { routes, @@ -72,25 +42,9 @@ pub async fn list_routes_received( pub async fn list_routes_filtered( Path(id): Path, ) -> Result { - let result = bird::birdc(bird::Command::ShowRouteAllFilteredProtocol(id))?; - let buf = BufReader::new(result); - let blocks = BlockIterator::new(buf, &RE_ROUTES_START); - let mut routes: Vec = vec![]; - - // Spawn workers - let (blocks_tx, mut results_rx) = RoutesWorkerPool::spawn(); - - task::spawn_blocking(move || { - for block in blocks { - blocks_tx.send(block).unwrap(); - } - }); - - while let Some(result) = results_rx.recv().await { - let result = result?; - routes.extend(result); - } - + let birdc = Birdc::default(); + let protocol = ProtocolID::parse(&id)?; + let routes = birdc.show_route_all_filtered_protocol(&protocol).await?; let response = RoutesResponse { routes, ..Default::default() @@ -104,23 +58,9 @@ pub async fn list_routes_filtered( pub async fn list_routes_noexport( Path(id): Path, ) -> Result { - let result = bird::birdc(bird::Command::ShowRouteAllNoexportProtocol(id))?; - let buf = BufReader::new(result); - let blocks = BlockIterator::new(buf, &RE_ROUTES_START); - let mut routes: Vec = vec![]; - - // Spawn workers - let (blocks_tx, mut results_rx) = RoutesWorkerPool::spawn(); - task::spawn_blocking(move || { - for block in blocks { - blocks_tx.send(block).unwrap(); - } - }); - - while let Some(result) = results_rx.recv().await { - let result = result?; - routes.extend(result); - } + let birdc = Birdc::default(); + let protocol = ProtocolID::parse(&id)?; + let routes = birdc.show_route_all_noexport_protocol(&protocol).await?; let response = RoutesResponse { routes, diff --git a/src/api/responses.rs b/src/api/responses.rs index 582572a..a22cf95 100644 --- a/src/api/responses.rs +++ b/src/api/responses.rs @@ -22,16 +22,36 @@ impl Default for StatusResponse { } } -#[derive(Serialize, Deserialize, Debug, Default)] +#[derive(Serialize, Deserialize, Debug)] pub struct NeighborsResponse { pub api: ApiStatus, pub cached_at: DateTime, pub protocols: HashMap, } -#[derive(Serialize, Deserialize, Debug, Default)] +impl Default for NeighborsResponse { + fn default() -> Self { + NeighborsResponse { + api: ApiStatus::default(), + cached_at: Utc::now(), + protocols: HashMap::new(), + } + } +} + +#[derive(Serialize, Deserialize, Debug)] pub struct RoutesResponse { pub api: ApiStatus, pub cached_at: DateTime, pub routes: Vec, } + +impl Default for RoutesResponse { + fn default() -> Self { + RoutesResponse { + api: ApiStatus::default(), + cached_at: Utc::now(), + routes: Vec::new(), + } + } +} diff --git a/src/api/server.rs b/src/api/server.rs index 57e7cdd..e3f9ac3 100644 --- a/src/api/server.rs +++ b/src/api/server.rs @@ -1,5 +1,6 @@ use anyhow::Result; use axum::{routing::get, Router}; +use tower_http::trace::TraceLayer; use crate::api::{neighbors, status, tables}; @@ -17,7 +18,6 @@ async fn welcome() -> &'static str { /// Start the API http server pub async fn start(opts: &Opts) -> Result<()> { - let addr = opts.listen.parse()?; let app = Router::new() .route("/", get(welcome)) .route("/status", get(status::retrieve)) @@ -38,11 +38,11 @@ pub async fn start(opts: &Opts) -> Result<()> { .route( "/routes/table/:table/filtered", get(tables::list_routes_filtered), - ); + ) + .layer(TraceLayer::new_for_http()); - axum::Server::bind(&addr) - .serve(app.into_make_service()) - .await?; + let listener = tokio::net::TcpListener::bind(&opts.listen).await?; + axum::serve(listener, app).await?; Ok(()) } diff --git a/src/api/status.rs b/src/api/status.rs index e6a60fc..0acf685 100644 --- a/src/api/status.rs +++ b/src/api/status.rs @@ -1,19 +1,15 @@ -use std::io::{BufRead, BufReader}; - use anyhow::Result; -use crate::api::{responses::StatusResponse, Error}; -use crate::bird; -use crate::parsers::parser::Parse; -use crate::state::{ApiStatus, BirdStatus}; +use crate::{ + api::{responses::StatusResponse, Error}, + bird::Birdc, + state::ApiStatus, +}; /// Get the current status pub async fn retrieve() -> Result { - let result = bird::birdc(bird::Command::ShowStatus)?; - let reader = BufReader::new(result); - let block = reader.lines().map(|l| l.unwrap()).collect::>(); - let status = BirdStatus::parse(block).unwrap(); - + let birdc = Birdc::default(); + let status = birdc.show_status().await?; let response = StatusResponse { api: ApiStatus::default(), status, diff --git a/src/api/tables.rs b/src/api/tables.rs index cf28846..bfeafbd 100644 --- a/src/api/tables.rs +++ b/src/api/tables.rs @@ -1,40 +1,17 @@ -use std::io::BufReader; - use anyhow::Result; use axum::extract::Path; -use tokio::task; use crate::{ api::{responses::RoutesResponse, Error}, - bird, - parsers::{ - parser::BlockIterator, routes::RE_ROUTES_START, - routes_worker::RoutesWorkerPool, - }, - state::Route, + bird::{Birdc, TableID}, }; /// List all routes in a table pub async fn list_routes(Path(table): Path) -> Result { - let result = bird::birdc(bird::Command::ShowRouteAllTable(table))?; - let buf = BufReader::new(result); - let blocks = BlockIterator::new(buf, &RE_ROUTES_START); - let mut routes: Vec = vec![]; - - // Spawn workers - let (blocks_tx, mut results_rx) = RoutesWorkerPool::spawn(); - - task::spawn_blocking(move || { - for block in blocks { - blocks_tx.send(block).unwrap(); - } - }); - - while let Some(result) = results_rx.recv().await { - let result = result?; - routes.extend(result); - } + let birdc = Birdc::default(); + let table = TableID::parse(&table)?; + let routes = birdc.show_route_all_table(&table).await?; let response = RoutesResponse { routes, ..Default::default() @@ -47,23 +24,9 @@ pub async fn list_routes(Path(table): Path) -> Result { pub async fn list_routes_filtered( Path(table): Path, ) -> Result { - let result = bird::birdc(bird::Command::ShowRouteAllFilteredTable(table))?; - let buf = BufReader::new(result); - let blocks = BlockIterator::new(buf, &RE_ROUTES_START); - let mut routes: Vec = vec![]; - - // Spawn workers - let (blocks_tx, mut results_rx) = RoutesWorkerPool::spawn(); - task::spawn_blocking(move || { - for block in blocks { - blocks_tx.send(block).unwrap(); - } - }); - - while let Some(result) = results_rx.recv().await { - let result = result?; - routes.extend(result); - } + let birdc = Birdc::default(); + let table = TableID::parse(&table)?; + let routes = birdc.show_route_all_filtered_table(&table).await?; let response = RoutesResponse { routes, @@ -72,3 +35,4 @@ pub async fn list_routes_filtered( let body = serde_json::to_string(&response)?; Ok(body) } + diff --git a/src/bird.rs b/src/bird.rs index ed48f9d..d00c26a 100644 --- a/src/bird.rs +++ b/src/bird.rs @@ -1,68 +1,284 @@ -use std::io::Write; -use std::os::unix::net::UnixStream; +use std::{ + fmt::Display, + io::{BufReader, Write}, + os::unix::net::UnixStream, +}; use anyhow::Result; +use lazy_static::lazy_static; +use regex::Regex; +use thiserror::Error; +use tokio::task; -use crate::config; +use crate::{ + config, + parsers::{ + neighbors::NeighborReader, + parser::{BlockIterator, Parse}, + routes::RE_ROUTES_START, + routes_worker::RoutesWorkerPool, + }, + state::{BirdStatus, Neighbor, NeighborsMap, Route}, +}; -pub enum Command { - ShowStatus, - ShowProtocolsAll, - ShowRouteAllProtocol(String), - ShowRouteAllFilteredProtocol(String), - ShowRouteAllNoexportProtocol(String), - ShowRouteAllTable(String), - ShowRouteAllFilteredTable(String), +lazy_static! { + /// Regex for start neighbor + static ref RE_STATUS_START: Regex = Regex::new(r"\d\d\d\d\s").unwrap(); } -/// Remove potentially harmful characters from the string -fn sanitize_userdata(s: String) -> String { - s.replace("'", "_") - .replace("`", "_") - .replace("\"", "_") - .replace("\n", "_") - .replace("\t", "_") - .replace(",", "_") - .replace(";", "_") +#[derive(Error, Debug)] +pub struct ValidationError { + input: String, + reason: String, } -impl Into for Command { - fn into(self) -> String { - match self { - Command::ShowStatus => "show status\n".to_string(), - Command::ShowProtocolsAll => "show protocols all\n".to_string(), - Command::ShowRouteAllProtocol(id) => { - let id = sanitize_userdata(id); - format!("show route all protocol '{}'\n", id) - } - Command::ShowRouteAllFilteredProtocol(id) => { - let id = sanitize_userdata(id); - format!("show route all filtered protocol '{}'\n", id) - } - Command::ShowRouteAllNoexportProtocol(id) => { - let id = sanitize_userdata(id); - format!("show route all noexport protocol '{}'\n", id) - } - Command::ShowRouteAllTable(table) => { - let table = sanitize_userdata(table); - format!("show route all table '{}'\n", table) - } - Command::ShowRouteAllFilteredTable(table) => { - let table = sanitize_userdata(table); - format!("show route all filtered table '{}'\n", table) - } +impl Display for ValidationError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Validation failed '{}': {}", self.input, self.reason) + } +} + +// Validation helpers + +/// Basic string validation +fn validate_string(s: &str) -> Result<()> { + if s.is_empty() { + return Err(ValidationError { + input: s.to_string(), + reason: "is empty".to_string(), + } + .into()); + } + + if s.len() > 128 { + return Err(ValidationError { + input: s.to_string(), + reason: "is too long".to_string(), + } + .into()); + } + + // Only allow [a-zA-Z0-9_] + if !s.chars().all(|c| c.is_ascii_alphanumeric() || c == '_') { + return Err(ValidationError { + input: s.to_string(), + reason: "contains invalid characters".to_string(), + } + .into()); + } + + Ok(()) +} + +// Request Types + +/// TableID represents a table name like master4 +pub struct TableID(String); + +impl TableID { + /// Parse a table id from a string. This will fail + /// if the input is invalid. + pub fn parse(s: &str) -> Result { + let table = s.to_string(); + validate_string(&table)?; + + Ok(Self(table)) + } + + /// Get the table id as string + pub fn as_str(&self) -> &str { + &self.0 + } +} + +impl Display for TableID { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + +/// ProtocolID represents a neighbor identifier. +/// Valid characters are [a-zA-Z0-9_]. +pub struct ProtocolID(String); + +impl ProtocolID { + /// Parse a protocol id from a string. This will fail + /// if the input is invalid. + pub fn parse(s: &str) -> Result { + let protocol = s.to_string(); + validate_string(&protocol)?; + + Ok(Self(protocol)) + } + + /// Get the protocol id as string + pub fn as_str(&self) -> &str { + &self.0 + } +} + +impl Display for ProtocolID { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + +pub struct Birdc { + socket: String, +} + +impl Default for Birdc { + fn default() -> Self { + Self { + socket: config::get_birdc_socket(), } } } -/// Connect to birdc on the unix socket -/// and send the command. -pub fn birdc(cmd: Command) -> Result { - let socket_addr = config::get_birdc_socket(); - let mut stream = UnixStream::connect(socket_addr)?; - let req: String = cmd.into(); +impl Birdc { + /// Create new birdc instance + pub fn new(socket: String) -> Self { + Self { socket } + } + + /// Get the daemon status. + pub async fn show_status(&self) -> Result { + let mut stream = UnixStream::connect(&self.socket)?; + + let cmd = format!("show status\n"); + stream.write_all(&cmd.as_bytes())?; + + let reader = BufReader::new(stream); + let mut iter = BlockIterator::new(reader, &RE_STATUS_START); + let block = iter.next().unwrap(); + let status = BirdStatus::parse(block)?; + + Ok(status) + } + + /// Get neighbors + pub async fn show_protocols_all(&self) -> Result { + let mut stream = UnixStream::connect(&self.socket)?; + let cmd = format!("show protocols all\n"); + stream.write_all(&cmd.as_bytes())?; + + let buf = BufReader::new(stream); + let reader = NeighborReader::new(buf); + let neighbors: Vec = + reader.filter(|n| !n.id.is_empty()).collect(); + + let neighbors: NeighborsMap = + neighbors.into_iter().map(|n| (n.id.clone(), n)).collect(); + + Ok(neighbors) + } + + /// Send the command to the birdc socket and parse the response. + /// Please note that only show route commands can be used here. + async fn fetch_routes_cmd(&self, cmd: &str) -> Result> { + let mut stream = UnixStream::connect(&self.socket)?; + stream.write_all(&cmd.as_bytes())?; + let buf = BufReader::new(stream); + + let blocks = BlockIterator::new(buf, &RE_ROUTES_START); + let mut routes: Vec = vec![]; + + // Spawn workers and fill queue + let (blocks_tx, mut results_rx) = RoutesWorkerPool::spawn(); + task::spawn_blocking(move || { + for block in blocks { + blocks_tx.send(block).unwrap(); + } + }); + + // Collect results + while let Some(result) = results_rx.recv().await { + let result = result?; + routes.extend(result); + } + + Ok(routes) + } + + /// Get routes for a table + pub async fn show_route_all_table( + &self, + table: &TableID, + ) -> Result> { + let cmd = format!("show route all table '{}'\n", table); + let routes = self.fetch_routes_cmd(&cmd).await?; + Ok(routes) + } + + /// Get filtered routes for a table + pub async fn show_route_all_filtered_table( + &self, + table: &TableID, + ) -> Result> { + let cmd = format!("show route all filtered table '{}'\n", table); + let routes = self.fetch_routes_cmd(&cmd).await?; + Ok(routes) + } + + /// Get routes for a neighbor + pub async fn show_route_all_protocol( + &self, + protocol: &ProtocolID, + ) -> Result> { + let cmd = format!("show route all protocol '{}'\n", protocol); + let routes = self.fetch_routes_cmd(&cmd).await?; + Ok(routes) + } + + /// Get routes for a neighbor + pub async fn show_route_all_filtered_protocol( + &self, + protocol: &ProtocolID, + ) -> Result> { + let cmd = format!("show route all filtered protocol '{}'\n", protocol); + let routes = self.fetch_routes_cmd(&cmd).await?; + Ok(routes) + } - stream.write_all(&req.as_bytes())?; - Ok(stream) + /// Get noexport routes for a neighbor + pub async fn show_route_all_noexport_protocol( + &self, + protocol: &ProtocolID, + ) -> Result> { + // TODO: check command + let cmd = format!("show route all noexport protocol '{}'\n", protocol); + let routes = self.fetch_routes_cmd(&cmd).await?; + Ok(routes) + } } +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_table_id() { + let table = TableID::parse("master4").unwrap(); + assert_eq!(table.as_str(), "master4"); + + // Invalid table name + let result = TableID::parse("m4'"); + assert!(result.is_err()); + } + + #[test] + fn test_protocol_id() { + let protocol = ProtocolID::parse("R192_175").unwrap(); + assert_eq!(protocol.as_str(), "R192_175"); + + // Invalid table name + let result = ProtocolID::parse("R192_175'"); + assert!(result.is_err()); + + let result = ProtocolID::parse("R192 175"); + assert!(result.is_err()); + + let result = ProtocolID::parse("R192`date`175"); + assert!(result.is_err()); + } +} diff --git a/src/lib.rs b/src/lib.rs index eb15002..1e964d8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,3 +3,7 @@ pub mod bird; pub mod config; pub mod parsers; pub mod state; + +pub fn version() -> String { + env!("CARGO_PKG_VERSION").to_string() +} diff --git a/src/main.rs b/src/main.rs index 2dedcfc..1ef9960 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,20 +1,34 @@ use anyhow::Result; +use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; use lightwatcher::api::{self, server::Opts}; use lightwatcher::config; #[tokio::main] async fn main() -> Result<()> { - // Print info - let listen = config::get_listen_address(); - let birdc_socket = config::get_birdc_socket(); + // Setup tracing + tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| { + "lightwatcher=info,axum::rejection=trace,tower_http=debug" + .into() + }), + ) + .with(tracing_subscriber::fmt::layer()) + .init(); - println!("lightwatcher v0.0.1\n"); - println!(" LIGHTWATCHER_LISTEN: {}", listen); - println!(" LIGHTWATCHER_BIRDC: {}", birdc_socket); - println!("\n"); + // Print info + tracing::info!("starting {}", lightwatcher::version()); + tracing::info!( + "ENV: LIGHTWATCHER_LISTEN={}", + config::get_listen_address() + ); + tracing::info!("ENV: LIGHTWATCHER_BIRDC={}", config::get_birdc_socket()); // Start API server + let listen = config::get_listen_address(); api::server::start(&Opts { listen }).await?; + Ok(()) }