diff --git a/pywr-schema/src/data_tables/mod.rs b/pywr-schema/src/data_tables/mod.rs index cb83cf80..16ef5fa3 100644 --- a/pywr-schema/src/data_tables/mod.rs +++ b/pywr-schema/src/data_tables/mod.rs @@ -5,6 +5,8 @@ mod vec; use crate::parameters::TableIndex; use crate::ConversionError; +#[cfg(feature = "core")] +use crate::SchemaError; use pywr_schema_macros::PywrVisitAll; use pywr_v1_schema::parameters::TableDataRef as TableDataRefV1; #[cfg(feature = "core")] @@ -19,7 +21,10 @@ use thiserror::Error; #[cfg(feature = "core")] use tracing::{debug, info}; #[cfg(feature = "core")] -use vec::{load_csv_row2_vec_table_one, load_csv_row_vec_table_one, LoadedVecTable}; +use vec::{ + load_csv_col1_vec_table_one, load_csv_col2_vec_table_two, load_csv_row2_vec_table_one, load_csv_row_vec_table_one, + LoadedVecTable, +}; #[derive(serde::Deserialize, serde::Serialize, Debug, Clone, JsonSchema)] #[serde(rename_all = "lowercase")] @@ -108,7 +113,17 @@ impl CsvDataTable { "CSV row array table with more than two index columns is not supported.".to_string(), )), }, - CsvDataTableLookup::Col(_) => todo!(), + CsvDataTableLookup::Col(i) => match i { + 1 => Ok(LoadedTable::FloatVec(load_csv_col1_vec_table_one( + &self.url, data_path, + )?)), + 2 => Ok(LoadedTable::FloatVec(load_csv_col2_vec_table_two( + &self.url, data_path, + )?)), + _ => Err(TableError::FormatNotSupported( + "CSV column array table with more than two index columns is not supported.".to_string(), + )), + }, CsvDataTableLookup::Both(_, _) => todo!(), }, } @@ -156,6 +171,8 @@ pub enum TableError { TooManyValues(PathBuf), #[error("table index out of bounds: {0}")] IndexOutOfBounds(usize), + #[error("Table format invalid: {0}")] + InvalidFormat(String), } #[cfg(feature = "core")] @@ -192,13 +209,16 @@ pub struct LoadedTableCollection { #[cfg(feature = "core")] impl LoadedTableCollection { - pub fn from_schema(table_defs: Option<&[DataTable]>, data_path: Option<&Path>) -> Result { + pub fn from_schema(table_defs: Option<&[DataTable]>, data_path: Option<&Path>) -> Result { let mut tables = HashMap::new(); if let Some(table_defs) = table_defs { for table_def in table_defs { let name = table_def.name().to_string(); info!("Loading table: {}", &name); - let table = table_def.load(data_path)?; + let table = table_def.load(data_path).map_err(|error| SchemaError::TableLoad { + table_def: table_def.clone(), + error, + })?; // TODO handle duplicate table names! tables.insert(name, table); } @@ -215,7 +235,6 @@ impl LoadedTableCollection { /// Return a single scalar value from a table collection. pub fn get_scalar_f64(&self, table_ref: &TableDataRef) -> Result { - debug!("Looking-up float scalar with reference: {:?}", table_ref); let tbl = self.get_table(&table_ref.table)?; let key = table_ref.key(); tbl.get_scalar_f64(&key) diff --git a/pywr-schema/src/data_tables/scalar.rs b/pywr-schema/src/data_tables/scalar.rs index 2f0bc643..dcf2bf70 100644 --- a/pywr-schema/src/data_tables/scalar.rs +++ b/pywr-schema/src/data_tables/scalar.rs @@ -32,7 +32,7 @@ where pub struct ScalarTableR1C1 { index: (Vec, Vec), // Could this be flattened for a small performance gain? - values: Vec>, + values: Vec>>, } impl ScalarTableR1C1 @@ -44,12 +44,15 @@ where let idx0 = table_key_to_position(index[0], &self.index.0)?; let idx1 = table_key_to_position(index[1], &self.index.1)?; - self.values + let value = self + .values .get(idx0) .ok_or(TableError::IndexOutOfBounds(idx0))? .get(idx1) - .ok_or(TableError::IndexOutOfBounds(idx1)) - .copied() + .ok_or(TableError::IndexOutOfBounds(idx1))? + .ok_or_else(|| TableError::EntryNotFound)?; + + Ok(value) } else { Err(TableError::WrongKeySize(2, index.len())) } @@ -149,7 +152,7 @@ where .collect(); let mut row_headers: Vec = Vec::new(); - let values: Vec> = rdr + let values: Vec>> = rdr .records() .map(|result| { // The iterator yields Result, so we check the @@ -158,7 +161,7 @@ where let key = record.get(0).ok_or(TableError::KeyParse)?.to_string(); - let values: Vec = record.iter().skip(1).map(|v| v.parse()).collect::>()?; + let values: Vec> = record.iter().skip(1).map(|v| v.parse::().ok()).collect(); row_headers.push(key.clone()); diff --git a/pywr-schema/src/data_tables/vec.rs b/pywr-schema/src/data_tables/vec.rs index 0968e1ea..24c81f8b 100644 --- a/pywr-schema/src/data_tables/vec.rs +++ b/pywr-schema/src/data_tables/vec.rs @@ -112,3 +112,112 @@ where Ok(LoadedVecTable::Two(tbl)) } + +pub fn load_csv_col1_vec_table_one( + table_path: &Path, + data_path: Option<&Path>, +) -> Result, TableError> +where + T: FromStr, + TableError: From, +{ + let path = make_path(table_path, data_path); + + let file = File::open(path).map_err(|e| TableError::IO(e.to_string()))?; + let buf_reader = BufReader::new(file); + let mut rdr = csv::Reader::from_reader(buf_reader); + + let mut tbl: HashMap> = HashMap::new(); + + // Read the headers + let headers: Vec = rdr + .headers() + .map_err(|e| TableError::Csv(e.to_string()))? + .iter() + .map(|s| s.to_string()) + .collect(); + + for header in headers.iter() { + tbl.insert(header.clone(), Vec::new()); + } + + for result in rdr.records() { + // The iterator yields Result, so we check the + // error here. + let record = result.map_err(|e| TableError::Csv(e.to_string()))?; + + for (col_idx, value) in record.iter().enumerate() { + let value: T = value.parse()?; + let key = headers.get(col_idx).ok_or_else(|| { + TableError::InvalidFormat(format!( + "Value index ({}) is out of bounds for a table with {} headers.", + col_idx, + headers.len() + )) + })?; + tbl.get_mut(key).unwrap().push(value); + } + } + + Ok(LoadedVecTable::One(tbl)) +} + +pub fn load_csv_col2_vec_table_two( + table_path: &Path, + data_path: Option<&Path>, +) -> Result, TableError> +where + T: FromStr, + TableError: From, +{ + let path = make_path(table_path, data_path); + + let file = File::open(path).map_err(|e| TableError::IO(e.to_string()))?; + let buf_reader = BufReader::new(file); + let mut rdr = csv::Reader::from_reader(buf_reader); + + let mut tbl: HashMap<(String, String), Vec> = HashMap::new(); + + // Read the headers + let headers1: Vec = rdr + .headers() + .map_err(|e| TableError::Csv(e.to_string()))? + .iter() + .map(|s| s.to_string()) + .collect(); + + let mut records = rdr.records(); + // Read the second row as the second headers + let headers2: Vec = records + .next() + .ok_or_else(|| TableError::WrongTableFormat("Second row of headers found".to_string()))? + .map_err(|e| TableError::Csv(e.to_string()))? + .iter() + .map(|s| s.to_string()) + .collect(); + + let headers: Vec<_> = headers1.into_iter().zip(headers2).collect(); + for header in &headers { + tbl.insert(header.clone(), Vec::new()); + } + + for result in records { + // The iterator yields Result, so we check the + // error here. + let record = result.map_err(|e| TableError::Csv(e.to_string()))?; + + for (col_idx, value) in record.iter().enumerate() { + let value: T = value.parse()?; + let key = headers.get(col_idx).ok_or_else(|| { + TableError::InvalidFormat(format!( + "Value index ({}) is out of bounds for a table with {} headers.", + col_idx, + headers.len() + )) + })?; + tbl.get_mut(key).unwrap().push(value); + } + } + + Ok(LoadedVecTable::Two(tbl)) +} diff --git a/pywr-schema/src/error.rs b/pywr-schema/src/error.rs index d079df0c..d67cee32 100644 --- a/pywr-schema/src/error.rs +++ b/pywr-schema/src/error.rs @@ -1,4 +1,4 @@ -use crate::data_tables::TableError; +use crate::data_tables::{DataTable, TableDataRef, TableError}; use crate::nodes::NodeAttribute; use crate::timeseries::TimeseriesError; use thiserror::Error; @@ -26,8 +26,10 @@ pub enum SchemaError { #[error("Pywr core error: {0}")] #[cfg(feature = "core")] PywrCore(#[from] pywr_core::PywrError), - #[error("data table error: {0}")] - DataTable(#[from] TableError), + #[error("Error loading data from table `{0}` (column: `{1:?}`, index: `{2:?}`) error: {error}", table_ref.table, table_ref.column, table_ref.index)] + TableRefLoad { table_ref: TableDataRef, error: TableError }, + #[error("Error loading table `{table_def:?}` error: {error}")] + TableLoad { table_def: DataTable, error: TableError }, #[error("Circular node reference(s) found.")] CircularNodeReference, #[error("Circular parameters reference(s) found. Unable to load the following parameters: {0:?}")] diff --git a/pywr-schema/src/metric.rs b/pywr-schema/src/metric.rs index 5509b5ba..93c0c546 100644 --- a/pywr-schema/src/metric.rs +++ b/pywr-schema/src/metric.rs @@ -58,7 +58,13 @@ impl Metric { Self::Parameter(parameter_ref) => parameter_ref.load(network), Self::Constant { value } => Ok((*value).into()), Self::Table(table_ref) => { - let value = args.tables.get_scalar_f64(table_ref)?; + let value = args + .tables + .get_scalar_f64(table_ref) + .map_err(|error| SchemaError::TableRefLoad { + table_ref: table_ref.clone(), + error, + })?; Ok(value.into()) } Self::Timeseries(ts_ref) => { diff --git a/pywr-schema/src/model.rs b/pywr-schema/src/model.rs index 5d4abaec..9c78e52b 100644 --- a/pywr-schema/src/model.rs +++ b/pywr-schema/src/model.rs @@ -382,7 +382,7 @@ impl PywrNetwork { #[cfg(feature = "core")] pub fn load_tables(&self, data_path: Option<&Path>) -> Result { - Ok(LoadedTableCollection::from_schema(self.tables.as_deref(), data_path)?) + LoadedTableCollection::from_schema(self.tables.as_deref(), data_path) } #[cfg(feature = "core")] diff --git a/pywr-schema/src/parameters/mod.rs b/pywr-schema/src/parameters/mod.rs index 14063c01..ac946e90 100644 --- a/pywr-schema/src/parameters/mod.rs +++ b/pywr-schema/src/parameters/mod.rs @@ -725,7 +725,12 @@ impl ConstantValue { pub fn load(&self, tables: &LoadedTableCollection) -> Result { match self { Self::Literal(v) => Ok(*v), - Self::Table(tbl_ref) => Ok(tables.get_scalar_f64(tbl_ref)?), + Self::Table(tbl_ref) => tables + .get_scalar_f64(tbl_ref) + .map_err(|error| SchemaError::TableRefLoad { + table_ref: tbl_ref.clone(), + error, + }), } } } @@ -736,7 +741,12 @@ impl ConstantValue { pub fn load(&self, tables: &LoadedTableCollection) -> Result { match self { Self::Literal(v) => Ok(*v), - Self::Table(tbl_ref) => Ok(tables.get_scalar_usize(tbl_ref)?), + Self::Table(tbl_ref) => tables + .get_scalar_usize(tbl_ref) + .map_err(|error| SchemaError::TableRefLoad { + table_ref: tbl_ref.clone(), + error, + }), } } } @@ -867,7 +877,13 @@ impl ConstantFloatVec { pub fn load(&self, tables: &LoadedTableCollection) -> Result, SchemaError> { match self { Self::Literal(v) => Ok(v.clone()), - Self::Table(tbl_ref) => Ok(tables.get_vec_f64(tbl_ref)?.clone()), + Self::Table(tbl_ref) => tables + .get_vec_f64(tbl_ref) + .cloned() + .map_err(|error| SchemaError::TableRefLoad { + table_ref: tbl_ref.clone(), + error, + }), } } }