Skip to content

Commit

Permalink
feat: Data table improvements
Browse files Browse the repository at this point in the history
- More informative error message when failing to load a
table reference.
- Implement loading array data with column headers.
  • Loading branch information
jetuk committed Oct 7, 2024
1 parent 030a7e9 commit da82e04
Show file tree
Hide file tree
Showing 7 changed files with 170 additions and 19 deletions.
25 changes: 20 additions & 5 deletions pywr-schema/src/data_tables/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@ mod scalar;
#[cfg(feature = "core")]
mod vec;

use crate::data_tables::vec::load_csv_col2_vec_table_one;
use crate::parameters::TableIndex;
use crate::ConversionError;
use crate::{ConversionError, SchemaError};
use pywr_schema_macros::PywrVisitAll;
use pywr_v1_schema::parameters::TableDataRef as TableDataRefV1;
#[cfg(feature = "core")]
Expand Down Expand Up @@ -108,7 +109,17 @@ impl CsvDataTable {
"CSV row array table with more than two index columns is not supported.".to_string(),
)),
},
CsvDataTableLookup::Col(_) => todo!(),
CsvDataTableLookup::Col(i) => match i {
1 => Ok(LoadedTable::FloatVec(load_csv_col2_vec_table_one(
&self.url, data_path,
)?)),
2 => Ok(LoadedTable::FloatVec(load_csv_col2_vec_table_one(
&self.url, data_path,
)?)),
_ => Err(TableError::FormatNotSupported(
"CSV column array table with more than two index columns is not supported.".to_string(),
)),
},
CsvDataTableLookup::Both(_, _) => todo!(),
},
}
Expand Down Expand Up @@ -156,6 +167,8 @@ pub enum TableError {
TooManyValues(PathBuf),
#[error("table index out of bounds: {0}")]
IndexOutOfBounds(usize),
#[error("Table format invalid: {0}")]
InvalidFormat(String),
}

#[cfg(feature = "core")]
Expand Down Expand Up @@ -192,13 +205,16 @@ pub struct LoadedTableCollection {

#[cfg(feature = "core")]
impl LoadedTableCollection {
pub fn from_schema(table_defs: Option<&[DataTable]>, data_path: Option<&Path>) -> Result<Self, TableError> {
pub fn from_schema(table_defs: Option<&[DataTable]>, data_path: Option<&Path>) -> Result<Self, SchemaError> {
let mut tables = HashMap::new();
if let Some(table_defs) = table_defs {
for table_def in table_defs {
let name = table_def.name().to_string();
info!("Loading table: {}", &name);
let table = table_def.load(data_path)?;
let table = table_def.load(data_path).map_err(|error| SchemaError::TableLoad {
table_def: table_def.clone(),
error,
})?;
// TODO handle duplicate table names!
tables.insert(name, table);
}
Expand All @@ -215,7 +231,6 @@ impl LoadedTableCollection {

/// Return a single scalar value from a table collection.
pub fn get_scalar_f64(&self, table_ref: &TableDataRef) -> Result<f64, TableError> {
debug!("Looking-up float scalar with reference: {:?}", table_ref);
let tbl = self.get_table(&table_ref.table)?;
let key = table_ref.key();
tbl.get_scalar_f64(&key)
Expand Down
15 changes: 9 additions & 6 deletions pywr-schema/src/data_tables/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ where
pub struct ScalarTableR1C1<T> {
index: (Vec<String>, Vec<String>),
// Could this be flattened for a small performance gain?
values: Vec<Vec<T>>,
values: Vec<Vec<Option<T>>>,
}

impl<T> ScalarTableR1C1<T>
Expand All @@ -44,12 +44,15 @@ where
let idx0 = table_key_to_position(index[0], &self.index.0)?;
let idx1 = table_key_to_position(index[1], &self.index.1)?;

self.values
let value = self
.values
.get(idx0)
.ok_or(TableError::IndexOutOfBounds(idx0))?
.get(idx1)
.ok_or(TableError::IndexOutOfBounds(idx1))
.copied()
.ok_or(TableError::IndexOutOfBounds(idx1))?
.ok_or_else(|| TableError::EntryNotFound)?;

Ok(value)
} else {
Err(TableError::WrongKeySize(2, index.len()))
}
Expand Down Expand Up @@ -149,7 +152,7 @@ where
.collect();

let mut row_headers: Vec<String> = Vec::new();
let values: Vec<Vec<T>> = rdr
let values: Vec<Vec<Option<T>>> = rdr
.records()
.map(|result| {
// The iterator yields Result<StringRecord, Error>, so we check the
Expand All @@ -158,7 +161,7 @@ where

let key = record.get(0).ok_or(TableError::KeyParse)?.to_string();

let values: Vec<T> = record.iter().skip(1).map(|v| v.parse()).collect::<Result<_, _>>()?;
let values: Vec<Option<T>> = record.iter().skip(1).map(|v| v.parse::<T>().ok()).collect();

row_headers.push(key.clone());

Expand Down
109 changes: 109 additions & 0 deletions pywr-schema/src/data_tables/vec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,3 +112,112 @@ where

Ok(LoadedVecTable::Two(tbl))
}

/// Load a CSV table whose single header row names each column.
///
/// Every row after the header contributes one parsed value of type `T` to the
/// vector of the column it falls under. The result is a map from column
/// header to that column's values.
///
/// # Errors
///
/// Returns a [`TableError`] if the file cannot be opened or read, if any
/// value fails to parse as `T`, or if a record contains more values than
/// there are headers.
pub fn load_csv_col1_vec_table_one<T>(
    table_path: &Path,
    data_path: Option<&Path>,
) -> Result<LoadedVecTable<T>, TableError>
where
    T: FromStr,
    TableError: From<T::Err>,
{
    let path = make_path(table_path, data_path);

    let reader = BufReader::new(File::open(path).map_err(|e| TableError::IO(e.to_string()))?);
    let mut rdr = csv::Reader::from_reader(reader);

    // The header row supplies the table keys.
    let headers: Vec<String> = rdr
        .headers()
        .map_err(|e| TableError::Csv(e.to_string()))?
        .iter()
        .map(String::from)
        .collect();

    // Start every column with an empty vector.
    let mut tbl: HashMap<String, Vec<T>> =
        headers.iter().map(|h| (h.clone(), Vec::new())).collect();

    for result in rdr.records() {
        // The iterator yields Result<StringRecord, Error>; surface CSV
        // errors here.
        let record = result.map_err(|e| TableError::Csv(e.to_string()))?;

        for (col_idx, raw) in record.iter().enumerate() {
            let parsed: T = raw.parse()?;
            let key = headers.get(col_idx).ok_or_else(|| {
                TableError::InvalidFormat(format!(
                    "Value index ({}) is out of bounds for a table with {} headers.",
                    col_idx,
                    headers.len()
                ))
            })?;
            // The key was inserted above from `headers`, so this cannot fail.
            tbl.get_mut(key).unwrap().push(parsed);
        }
    }

    Ok(LoadedVecTable::One(tbl))
}

/// Load a CSV table with two header rows, keyed by the pair of
/// (first-row, second-row) column headers.
///
/// The first two rows of the file are interpreted as headers; every
/// subsequent row contributes one parsed value of type `T` to the vector of
/// the column it falls under.
///
/// # Errors
///
/// Returns a [`TableError`] if the file cannot be opened or read, if the
/// second header row is missing, if the two header rows have different
/// lengths, if any value fails to parse as `T`, or if a record contains more
/// values than there are headers.
pub fn load_csv_col2_vec_table_one<T>(
    table_path: &Path,
    data_path: Option<&Path>,
) -> Result<LoadedVecTable<T>, TableError>
where
    T: FromStr,
    TableError: From<T::Err>,
{
    let path = make_path(table_path, data_path);

    let file = File::open(path).map_err(|e| TableError::IO(e.to_string()))?;
    let buf_reader = BufReader::new(file);
    let mut rdr = csv::Reader::from_reader(buf_reader);

    let mut tbl: HashMap<(String, String), Vec<T>> = HashMap::new();

    // Read the first row of headers.
    let headers1: Vec<String> = rdr
        .headers()
        .map_err(|e| TableError::Csv(e.to_string()))?
        .iter()
        .map(|s| s.to_string())
        .collect();

    let mut records = rdr.records();
    // Read the second row as the second level of headers. This error fires
    // when the row is absent, so the message must say "not found".
    let headers2: Vec<String> = records
        .next()
        .ok_or_else(|| TableError::WrongTableFormat("Second row of headers not found.".to_string()))?
        .map_err(|e| TableError::Csv(e.to_string()))?
        .iter()
        .map(|s| s.to_string())
        .collect();

    // `zip` would silently truncate to the shorter header row, dropping
    // columns without any diagnostic; reject ragged headers explicitly.
    if headers1.len() != headers2.len() {
        return Err(TableError::InvalidFormat(format!(
            "The two header rows have different lengths ({} and {}).",
            headers1.len(),
            headers2.len()
        )));
    }

    let headers: Vec<_> = headers1.into_iter().zip(headers2).collect();
    for header in &headers {
        tbl.insert(header.clone(), Vec::new());
    }

    for result in records {
        // The iterator yields Result<StringRecord, Error>, so we check the
        // error here.
        let record = result.map_err(|e| TableError::Csv(e.to_string()))?;

        for (col_idx, value) in record.iter().enumerate() {
            let value: T = value.parse()?;
            let key = headers.get(col_idx).ok_or_else(|| {
                TableError::InvalidFormat(format!(
                    "Value index ({}) is out of bounds for a table with {} headers.",
                    col_idx,
                    headers.len()
                ))
            })?;
            // The key was inserted above from `headers`, so this cannot fail.
            tbl.get_mut(key).unwrap().push(value);
        }
    }

    Ok(LoadedVecTable::Two(tbl))
}
8 changes: 5 additions & 3 deletions pywr-schema/src/error.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::data_tables::TableError;
use crate::data_tables::{DataTable, TableDataRef, TableError};
use crate::nodes::NodeAttribute;
use crate::timeseries::TimeseriesError;
use thiserror::Error;
Expand Down Expand Up @@ -26,8 +26,10 @@ pub enum SchemaError {
#[error("Pywr core error: {0}")]
#[cfg(feature = "core")]
PywrCore(#[from] pywr_core::PywrError),
#[error("data table error: {0}")]
DataTable(#[from] TableError),
#[error("Error loading data from table `{0}` (column: `{1:?}`, index: `{2:?}`) error: {error}", table_ref.table, table_ref.column, table_ref.index)]
TableRefLoad { table_ref: TableDataRef, error: TableError },
#[error("Error loading table `{table_def:?}` error: {error}")]
TableLoad { table_def: DataTable, error: TableError },
#[error("Circular node reference(s) found.")]
CircularNodeReference,
#[error("Circular parameters reference(s) found. Unable to load the following parameters: {0:?}")]
Expand Down
8 changes: 7 additions & 1 deletion pywr-schema/src/metric.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,13 @@ impl Metric {
Self::Parameter(parameter_ref) => parameter_ref.load(network),
Self::Constant { value } => Ok((*value).into()),
Self::Table(table_ref) => {
let value = args.tables.get_scalar_f64(table_ref)?;
let value = args
.tables
.get_scalar_f64(table_ref)
.map_err(|error| SchemaError::TableRefLoad {
table_ref: table_ref.clone(),
error,
})?;
Ok(value.into())
}
Self::Timeseries(ts_ref) => {
Expand Down
2 changes: 1 addition & 1 deletion pywr-schema/src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,7 @@ impl PywrNetwork {

#[cfg(feature = "core")]
pub fn load_tables(&self, data_path: Option<&Path>) -> Result<LoadedTableCollection, SchemaError> {
Ok(LoadedTableCollection::from_schema(self.tables.as_deref(), data_path)?)
LoadedTableCollection::from_schema(self.tables.as_deref(), data_path)
}

#[cfg(feature = "core")]
Expand Down
22 changes: 19 additions & 3 deletions pywr-schema/src/parameters/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -725,7 +725,12 @@ impl ConstantValue<f64> {
pub fn load(&self, tables: &LoadedTableCollection) -> Result<f64, SchemaError> {
match self {
Self::Literal(v) => Ok(*v),
Self::Table(tbl_ref) => Ok(tables.get_scalar_f64(tbl_ref)?),
Self::Table(tbl_ref) => tables
.get_scalar_f64(tbl_ref)
.map_err(|error| SchemaError::TableRefLoad {
table_ref: tbl_ref.clone(),
error,
}),
}
}
}
Expand All @@ -736,7 +741,12 @@ impl ConstantValue<usize> {
pub fn load(&self, tables: &LoadedTableCollection) -> Result<usize, SchemaError> {
match self {
Self::Literal(v) => Ok(*v),
Self::Table(tbl_ref) => Ok(tables.get_scalar_usize(tbl_ref)?),
Self::Table(tbl_ref) => tables
.get_scalar_usize(tbl_ref)
.map_err(|error| SchemaError::TableRefLoad {
table_ref: tbl_ref.clone(),
error,
}),
}
}
}
Expand Down Expand Up @@ -867,7 +877,13 @@ impl ConstantFloatVec {
pub fn load(&self, tables: &LoadedTableCollection) -> Result<Vec<f64>, SchemaError> {
match self {
Self::Literal(v) => Ok(v.clone()),
Self::Table(tbl_ref) => Ok(tables.get_vec_f64(tbl_ref)?.clone()),
Self::Table(tbl_ref) => tables
.get_vec_f64(tbl_ref)
.cloned()
.map_err(|error| SchemaError::TableRefLoad {
table_ref: tbl_ref.clone(),
error,
}),
}
}
}
Expand Down

0 comments on commit da82e04

Please sign in to comment.