Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Data table improvements #264

Merged
merged 2 commits into from
Oct 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 24 additions & 5 deletions pywr-schema/src/data_tables/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ mod vec;

use crate::parameters::TableIndex;
use crate::ConversionError;
#[cfg(feature = "core")]
use crate::SchemaError;
use pywr_schema_macros::PywrVisitAll;
use pywr_v1_schema::parameters::TableDataRef as TableDataRefV1;
#[cfg(feature = "core")]
Expand All @@ -19,7 +21,10 @@ use thiserror::Error;
#[cfg(feature = "core")]
use tracing::{debug, info};
#[cfg(feature = "core")]
use vec::{load_csv_row2_vec_table_one, load_csv_row_vec_table_one, LoadedVecTable};
use vec::{
load_csv_col1_vec_table_one, load_csv_col2_vec_table_two, load_csv_row2_vec_table_one, load_csv_row_vec_table_one,
LoadedVecTable,
};

#[derive(serde::Deserialize, serde::Serialize, Debug, Clone, JsonSchema)]
#[serde(rename_all = "lowercase")]
Expand Down Expand Up @@ -108,7 +113,17 @@ impl CsvDataTable {
"CSV row array table with more than two index columns is not supported.".to_string(),
)),
},
CsvDataTableLookup::Col(_) => todo!(),
CsvDataTableLookup::Col(i) => match i {
1 => Ok(LoadedTable::FloatVec(load_csv_col1_vec_table_one(
&self.url, data_path,
)?)),
2 => Ok(LoadedTable::FloatVec(load_csv_col2_vec_table_two(
&self.url, data_path,
)?)),
_ => Err(TableError::FormatNotSupported(
"CSV column array table with more than two index columns is not supported.".to_string(),
)),
},
CsvDataTableLookup::Both(_, _) => todo!(),
},
}
Expand Down Expand Up @@ -156,6 +171,8 @@ pub enum TableError {
TooManyValues(PathBuf),
#[error("table index out of bounds: {0}")]
IndexOutOfBounds(usize),
#[error("Table format invalid: {0}")]
InvalidFormat(String),
}

#[cfg(feature = "core")]
Expand Down Expand Up @@ -192,13 +209,16 @@ pub struct LoadedTableCollection {

#[cfg(feature = "core")]
impl LoadedTableCollection {
pub fn from_schema(table_defs: Option<&[DataTable]>, data_path: Option<&Path>) -> Result<Self, TableError> {
pub fn from_schema(table_defs: Option<&[DataTable]>, data_path: Option<&Path>) -> Result<Self, SchemaError> {
let mut tables = HashMap::new();
if let Some(table_defs) = table_defs {
for table_def in table_defs {
let name = table_def.name().to_string();
info!("Loading table: {}", &name);
let table = table_def.load(data_path)?;
let table = table_def.load(data_path).map_err(|error| SchemaError::TableLoad {
table_def: table_def.clone(),
error,
})?;
// TODO handle duplicate table names!
tables.insert(name, table);
}
Expand All @@ -215,7 +235,6 @@ impl LoadedTableCollection {

/// Return a single scalar value from a table collection.
pub fn get_scalar_f64(&self, table_ref: &TableDataRef) -> Result<f64, TableError> {
debug!("Looking-up float scalar with reference: {:?}", table_ref);
let tbl = self.get_table(&table_ref.table)?;
let key = table_ref.key();
tbl.get_scalar_f64(&key)
Expand Down
15 changes: 9 additions & 6 deletions pywr-schema/src/data_tables/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ where
pub struct ScalarTableR1C1<T> {
index: (Vec<String>, Vec<String>),
// Could this be flattened for a small performance gain?
values: Vec<Vec<T>>,
values: Vec<Vec<Option<T>>>,
}

impl<T> ScalarTableR1C1<T>
Expand All @@ -44,12 +44,15 @@ where
let idx0 = table_key_to_position(index[0], &self.index.0)?;
let idx1 = table_key_to_position(index[1], &self.index.1)?;

self.values
let value = self
.values
.get(idx0)
.ok_or(TableError::IndexOutOfBounds(idx0))?
.get(idx1)
.ok_or(TableError::IndexOutOfBounds(idx1))
.copied()
.ok_or(TableError::IndexOutOfBounds(idx1))?
.ok_or_else(|| TableError::EntryNotFound)?;

Ok(value)
} else {
Err(TableError::WrongKeySize(2, index.len()))
}
Expand Down Expand Up @@ -149,7 +152,7 @@ where
.collect();

let mut row_headers: Vec<String> = Vec::new();
let values: Vec<Vec<T>> = rdr
let values: Vec<Vec<Option<T>>> = rdr
.records()
.map(|result| {
// The iterator yields Result<StringRecord, Error>, so we check the
Expand All @@ -158,7 +161,7 @@ where

let key = record.get(0).ok_or(TableError::KeyParse)?.to_string();

let values: Vec<T> = record.iter().skip(1).map(|v| v.parse()).collect::<Result<_, _>>()?;
let values: Vec<Option<T>> = record.iter().skip(1).map(|v| v.parse::<T>().ok()).collect();

row_headers.push(key.clone());

Expand Down
109 changes: 109 additions & 0 deletions pywr-schema/src/data_tables/vec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,3 +112,112 @@ where

Ok(LoadedVecTable::Two(tbl))
}

/// Load a CSV file as a one-dimensional column-array table.
///
/// The first row of the CSV file is read as the header. Each header entry becomes a key in
/// the returned table, and every subsequent row contributes one parsed value per column.
/// Rows shorter than the header fill only their leading columns.
///
/// # Errors
///
/// Returns a [`TableError`] if the file cannot be opened or read, if a value fails to
/// parse as `T`, or if a record contains more values than there are headers.
pub fn load_csv_col1_vec_table_one<T>(
    table_path: &Path,
    data_path: Option<&Path>,
) -> Result<LoadedVecTable<T>, TableError>
where
    T: FromStr,
    TableError: From<T::Err>,
{
    let path = make_path(table_path, data_path);

    let file = File::open(path).map_err(|e| TableError::IO(e.to_string()))?;
    let buf_reader = BufReader::new(file);
    let mut rdr = csv::Reader::from_reader(buf_reader);

    // Read the headers; each header names one column of values.
    let headers: Vec<String> = rdr
        .headers()
        .map_err(|e| TableError::Csv(e.to_string()))?
        .iter()
        .map(|s| s.to_string())
        .collect();

    // Accumulate each column's values positionally. This avoids a hash-map lookup (and the
    // accompanying `unwrap`) for every cell; the map is built once at the end.
    // NOTE(review): duplicate header names collapse into a single map entry when collected
    // below — confirm headers are expected to be unique.
    let mut columns: Vec<Vec<T>> = headers.iter().map(|_| Vec::new()).collect();

    for result in rdr.records() {
        // The iterator yields Result<StringRecord, Error>, so we check the
        // error here.
        let record = result.map_err(|e| TableError::Csv(e.to_string()))?;

        if record.len() > headers.len() {
            return Err(TableError::InvalidFormat(format!(
                "Row with {} values is out of bounds for a table with {} headers.",
                record.len(),
                headers.len()
            )));
        }

        for (column, value) in columns.iter_mut().zip(record.iter()) {
            column.push(value.parse()?);
        }
    }

    let tbl: HashMap<String, Vec<T>> = headers.into_iter().zip(columns).collect();

    Ok(LoadedVecTable::One(tbl))
}

/// Load a CSV file as a two-dimensional column-array table.
///
/// The first two rows of the CSV file are read as headers. Each `(header1, header2)` pair
/// becomes a key in the returned table, and every subsequent row contributes one parsed
/// value per column. Rows shorter than the headers fill only their leading columns.
///
/// # Errors
///
/// Returns a [`TableError`] if the file cannot be opened or read, if the second header row
/// is missing, if the two header rows have different lengths, if a value fails to parse as
/// `T`, or if a record contains more values than there are headers.
pub fn load_csv_col2_vec_table_two<T>(
    table_path: &Path,
    data_path: Option<&Path>,
) -> Result<LoadedVecTable<T>, TableError>
where
    T: FromStr,
    TableError: From<T::Err>,
{
    let path = make_path(table_path, data_path);

    let file = File::open(path).map_err(|e| TableError::IO(e.to_string()))?;
    let buf_reader = BufReader::new(file);
    let mut rdr = csv::Reader::from_reader(buf_reader);

    // Read the first row of headers.
    let headers1: Vec<String> = rdr
        .headers()
        .map_err(|e| TableError::Csv(e.to_string()))?
        .iter()
        .map(|s| s.to_string())
        .collect();

    let mut records = rdr.records();
    // Read the second row as the second headers. The error message previously said the row
    // was "found" — it is raised precisely when the row is MISSING.
    let headers2: Vec<String> = records
        .next()
        .ok_or_else(|| TableError::WrongTableFormat("Second row of headers not found.".to_string()))?
        .map_err(|e| TableError::Csv(e.to_string()))?
        .iter()
        .map(|s| s.to_string())
        .collect();

    // `zip` silently truncates to the shorter side, which would drop columns; reject a
    // length mismatch explicitly instead.
    if headers1.len() != headers2.len() {
        return Err(TableError::InvalidFormat(format!(
            "The two header rows have different lengths ({} and {}).",
            headers1.len(),
            headers2.len()
        )));
    }

    let headers: Vec<(String, String)> = headers1.into_iter().zip(headers2).collect();

    // Accumulate each column's values positionally. This avoids a hash-map lookup (and the
    // accompanying `unwrap`) for every cell; the map is built once at the end.
    let mut columns: Vec<Vec<T>> = headers.iter().map(|_| Vec::new()).collect();

    for result in records {
        // The iterator yields Result<StringRecord, Error>, so we check the
        // error here.
        let record = result.map_err(|e| TableError::Csv(e.to_string()))?;

        if record.len() > headers.len() {
            return Err(TableError::InvalidFormat(format!(
                "Row with {} values is out of bounds for a table with {} headers.",
                record.len(),
                headers.len()
            )));
        }

        for (column, value) in columns.iter_mut().zip(record.iter()) {
            column.push(value.parse()?);
        }
    }

    let tbl: HashMap<(String, String), Vec<T>> = headers.into_iter().zip(columns).collect();

    Ok(LoadedVecTable::Two(tbl))
}
8 changes: 5 additions & 3 deletions pywr-schema/src/error.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::data_tables::TableError;
use crate::data_tables::{DataTable, TableDataRef, TableError};
use crate::nodes::NodeAttribute;
use crate::timeseries::TimeseriesError;
use thiserror::Error;
Expand Down Expand Up @@ -26,8 +26,10 @@ pub enum SchemaError {
#[error("Pywr core error: {0}")]
#[cfg(feature = "core")]
PywrCore(#[from] pywr_core::PywrError),
#[error("data table error: {0}")]
DataTable(#[from] TableError),
#[error("Error loading data from table `{0}` (column: `{1:?}`, index: `{2:?}`) error: {error}", table_ref.table, table_ref.column, table_ref.index)]
TableRefLoad { table_ref: TableDataRef, error: TableError },
#[error("Error loading table `{table_def:?}` error: {error}")]
TableLoad { table_def: DataTable, error: TableError },
#[error("Circular node reference(s) found.")]
CircularNodeReference,
#[error("Circular parameters reference(s) found. Unable to load the following parameters: {0:?}")]
Expand Down
8 changes: 7 additions & 1 deletion pywr-schema/src/metric.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,13 @@ impl Metric {
Self::Parameter(parameter_ref) => parameter_ref.load(network),
Self::Constant { value } => Ok((*value).into()),
Self::Table(table_ref) => {
let value = args.tables.get_scalar_f64(table_ref)?;
let value = args
.tables
.get_scalar_f64(table_ref)
.map_err(|error| SchemaError::TableRefLoad {
table_ref: table_ref.clone(),
error,
})?;
Ok(value.into())
}
Self::Timeseries(ts_ref) => {
Expand Down
2 changes: 1 addition & 1 deletion pywr-schema/src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,7 @@ impl PywrNetwork {

#[cfg(feature = "core")]
pub fn load_tables(&self, data_path: Option<&Path>) -> Result<LoadedTableCollection, SchemaError> {
Ok(LoadedTableCollection::from_schema(self.tables.as_deref(), data_path)?)
LoadedTableCollection::from_schema(self.tables.as_deref(), data_path)
}

#[cfg(feature = "core")]
Expand Down
22 changes: 19 additions & 3 deletions pywr-schema/src/parameters/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -725,7 +725,12 @@ impl ConstantValue<f64> {
pub fn load(&self, tables: &LoadedTableCollection) -> Result<f64, SchemaError> {
match self {
Self::Literal(v) => Ok(*v),
Self::Table(tbl_ref) => Ok(tables.get_scalar_f64(tbl_ref)?),
Self::Table(tbl_ref) => tables
.get_scalar_f64(tbl_ref)
.map_err(|error| SchemaError::TableRefLoad {
table_ref: tbl_ref.clone(),
error,
}),
}
}
}
Expand All @@ -736,7 +741,12 @@ impl ConstantValue<usize> {
pub fn load(&self, tables: &LoadedTableCollection) -> Result<usize, SchemaError> {
match self {
Self::Literal(v) => Ok(*v),
Self::Table(tbl_ref) => Ok(tables.get_scalar_usize(tbl_ref)?),
Self::Table(tbl_ref) => tables
.get_scalar_usize(tbl_ref)
.map_err(|error| SchemaError::TableRefLoad {
table_ref: tbl_ref.clone(),
error,
}),
}
}
}
Expand Down Expand Up @@ -867,7 +877,13 @@ impl ConstantFloatVec {
pub fn load(&self, tables: &LoadedTableCollection) -> Result<Vec<f64>, SchemaError> {
match self {
Self::Literal(v) => Ok(v.clone()),
Self::Table(tbl_ref) => Ok(tables.get_vec_f64(tbl_ref)?.clone()),
Self::Table(tbl_ref) => tables
.get_vec_f64(tbl_ref)
.cloned()
.map_err(|error| SchemaError::TableRefLoad {
table_ref: tbl_ref.clone(),
error,
}),
}
}
}
Expand Down
Loading