Skip to content

Commit

Permalink
create separate timeseries error enum
Browse files Browse the repository at this point in the history
  • Loading branch information
Batch21 committed Mar 3, 2024
1 parent 1b5c5b1 commit 72fe90e
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 32 deletions.
15 changes: 3 additions & 12 deletions pywr-schema/src/error.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::data_tables::TableError;
use crate::nodes::NodeAttribute;
use crate::timeseries::TimeseriesError;
use polars::error::PolarsError;
use pyo3::exceptions::PyRuntimeError;
use pyo3::PyErr;
Expand Down Expand Up @@ -29,18 +30,6 @@ pub enum SchemaError {
PywrCore(#[from] pywr_core::PywrError),
#[error("data table error: {0}")]
DataTable(#[from] TableError),
#[error("Timeseries '{0} not found")]
TimeseriesNotFound(String),
#[error("The duration of timeseries '{0}' could not be determined.")]
TimeseriesDurationNotFound(String),
#[error("Column '{col}' not found in timeseries input '{name}'")]
ColumnNotFound { col: String, name: String },
#[error("Timeseries provider '{provider}' does not support '{fmt}' file types")]
TimeseriesUnsupportedFileFormat { provider: String, fmt: String },
#[error("Timeseries provider '{provider}' cannot parse file: '{path}'")]
TimeseriesUnparsableFileFormat { provider: String, path: String },
#[error("Polars error: {0}")]
PolarsError(#[from] PolarsError),
#[error("Circular node reference(s) found.")]
CircularNodeReference,
#[error("Circular parameters reference(s) found.")]
Expand All @@ -67,6 +56,8 @@ pub enum SchemaError {
InvalidRollingWindow { name: String },
#[error("Failed to load parameter {name}: {error}")]
LoadParameter { name: String, error: String },
#[error("Timeseries error: {0}")]
Timeseries(#[from] TimeseriesError),
}

impl From<SchemaError> for PyErr {
Expand Down
10 changes: 5 additions & 5 deletions pywr-schema/src/timeseries/align_and_resample.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@ use polars::{prelude::*, series::ops::NullBehavior};
use pywr_core::models::ModelDomain;
use std::{cmp::Ordering, ops::Deref};

use crate::SchemaError;
use crate::timeseries::TimeseriesError;

pub fn align_and_resample(
name: &str,
df: DataFrame,
time_col: &str,
domain: &ModelDomain,
) -> Result<DataFrame, SchemaError> {
) -> Result<DataFrame, TimeseriesError> {
// Ensure type of time column is datetime and that it is sorted
let df = df
.clone()
Expand All @@ -36,7 +36,7 @@ pub fn align_and_resample(

let timeseries_duration = match durations.get(0) {
Some(duration) => duration,
None => return Err(SchemaError::TimeseriesDurationNotFound(name.to_string())),
None => return Err(TimeseriesError::TimeseriesDurationNotFound(name.to_string())),
};

let model_duration = domain
Expand Down Expand Up @@ -81,13 +81,13 @@ pub fn align_and_resample(
Ok(df)
}

fn slice_start(df: DataFrame, time_col: &str, domain: &ModelDomain) -> Result<DataFrame, SchemaError> {
/// Keeps only the rows of `df` whose `time_col` value is on or after the date
/// of the model domain's first timestep.
///
/// # Errors
///
/// Returns a [`TimeseriesError`] (converted from the underlying polars error)
/// if collecting the filtered lazy frame fails.
fn slice_start(df: DataFrame, time_col: &str, domain: &ModelDomain) -> Result<DataFrame, TimeseriesError> {
    let start = domain.time().first_timestep().date;
    // `df` is taken by value, so it can be moved straight into the lazy frame;
    // cloning it first is unnecessary.
    let df = df.lazy().filter(col(time_col).gt_eq(lit(start))).collect()?;
    Ok(df)
}

fn slice_end(df: DataFrame, time_col: &str, domain: &ModelDomain) -> Result<DataFrame, SchemaError> {
fn slice_end(df: DataFrame, time_col: &str, domain: &ModelDomain) -> Result<DataFrame, TimeseriesError> {
let end = domain.time().last_timestep().date;
let df = df.clone().lazy().filter(col(time_col).lt_eq(lit(end))).collect()?;
Ok(df)
Expand Down
42 changes: 32 additions & 10 deletions pywr-schema/src/timeseries/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,39 @@ mod align_and_resample;
mod polars_dataset;

use ndarray::Array2;
use polars::error::PolarsError;
use polars::prelude::DataType::Float64;
use polars::prelude::{DataFrame, Float64Type, IndexOrder};
use pywr_core::models::ModelDomain;
use pywr_core::parameters::{Array1Parameter, Array2Parameter, ParameterIndex};
use pywr_core::PywrError;
use std::{collections::HashMap, path::Path};
use thiserror::Error;

use crate::{parameters::ParameterMeta, SchemaError};
use crate::parameters::ParameterMeta;

use self::polars_dataset::PolarsDataset;

#[derive(Error, Debug)]
pub enum TimeseriesError {
#[error("Timeseries '{0} not found")]
TimeseriesNotFound(String),
#[error("The duration of timeseries '{0}' could not be determined.")]
TimeseriesDurationNotFound(String),
#[error("Column '{col}' not found in timeseries input '{name}'")]
ColumnNotFound { col: String, name: String },
#[error("Timeseries provider '{provider}' does not support '{fmt}' file types")]
TimeseriesUnsupportedFileFormat { provider: String, fmt: String },
#[error("Timeseries provider '{provider}' cannot parse file: '{path}'")]
TimeseriesUnparsableFileFormat { provider: String, path: String },
#[error("A scenario group with name '{0}' was not found")]
ScenarioGroupNotFound(String),
#[error("Polars error: {0}")]
PolarsError(#[from] PolarsError),
#[error("Pywr core error: {0}")]
PywrCore(#[from] pywr_core::PywrError),
}

#[derive(serde::Deserialize, serde::Serialize, Debug, Clone)]
#[serde(tag = "type")]
enum TimeseriesProvider {
Expand All @@ -28,7 +50,7 @@ pub struct Timeseries {
}

impl Timeseries {
pub fn load(&self, domain: &ModelDomain, data_path: Option<&Path>) -> Result<DataFrame, SchemaError> {
pub fn load(&self, domain: &ModelDomain, data_path: Option<&Path>) -> Result<DataFrame, TimeseriesError> {
match &self.provider {
TimeseriesProvider::Polars(dataset) => dataset.load(self.meta.name.as_str(), data_path, domain),
TimeseriesProvider::Pandas => todo!(),
Expand All @@ -45,7 +67,7 @@ impl LoadedTimeseriesCollection {
timeseries_defs: Option<&[Timeseries]>,
domain: &ModelDomain,
data_path: Option<&Path>,
) -> Result<Self, SchemaError> {
) -> Result<Self, TimeseriesError> {
let mut timeseries = HashMap::new();
if let Some(timeseries_defs) = timeseries_defs {
for ts in timeseries_defs {
Expand All @@ -62,11 +84,11 @@ impl LoadedTimeseriesCollection {
network: &mut pywr_core::network::Network,
name: &str,
col: &str,
) -> Result<ParameterIndex, SchemaError> {
) -> Result<ParameterIndex, TimeseriesError> {
let df = self
.timeseries
.get(name)
.ok_or(SchemaError::TimeseriesNotFound(name.to_string()))?;
.ok_or(TimeseriesError::TimeseriesNotFound(name.to_string()))?;
let series = df.column(col)?;

let array = series.cast(&Float64)?.f64()?.to_ndarray()?.to_owned();
Expand All @@ -79,7 +101,7 @@ impl LoadedTimeseriesCollection {
let p = Array1Parameter::new(&name, array, None);
Ok(network.add_parameter(Box::new(p))?)
}
_ => Err(SchemaError::PywrCore(e)),
_ => Err(TimeseriesError::PywrCore(e)),
},
}
}
Expand All @@ -90,16 +112,16 @@ impl LoadedTimeseriesCollection {
name: &str,
domain: &ModelDomain,
scenario: &str,
) -> Result<ParameterIndex, SchemaError> {
) -> Result<ParameterIndex, TimeseriesError> {
let scenario_group_index = domain
.scenarios()
.group_index(scenario)
.ok_or(SchemaError::ScenarioGroupNotFound(scenario.to_string()))?;
.ok_or(TimeseriesError::ScenarioGroupNotFound(scenario.to_string()))?;

let df = self
.timeseries
.get(name)
.ok_or(SchemaError::TimeseriesNotFound(name.to_string()))?;
.ok_or(TimeseriesError::TimeseriesNotFound(name.to_string()))?;

let array: Array2<f64> = df.to_ndarray::<Float64Type>(IndexOrder::default()).unwrap();
let name = format!("timeseries.{}_{}", name, scenario);
Expand All @@ -111,7 +133,7 @@ impl LoadedTimeseriesCollection {
let p = Array2Parameter::new(&name, array, scenario_group_index, None);
Ok(network.add_parameter(Box::new(p))?)
}
_ => Err(SchemaError::PywrCore(e)),
_ => Err(TimeseriesError::PywrCore(e)),
},
}
}
Expand Down
15 changes: 10 additions & 5 deletions pywr-schema/src/timeseries/polars_dataset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::path::{Path, PathBuf};
use polars::{frame::DataFrame, prelude::*};
use pywr_core::models::ModelDomain;

use crate::SchemaError;
use crate::timeseries::TimeseriesError;

use super::align_and_resample::align_and_resample;

Expand All @@ -14,7 +14,12 @@ pub struct PolarsDataset {
}

impl PolarsDataset {
pub fn load(&self, name: &str, data_path: Option<&Path>, domain: &ModelDomain) -> Result<DataFrame, SchemaError> {
pub fn load(
&self,
name: &str,
data_path: Option<&Path>,
domain: &ModelDomain,
) -> Result<DataFrame, TimeseriesError> {
let fp = if self.url.is_absolute() {
self.url.clone()
} else if let Some(data_path) = data_path {
Expand All @@ -36,21 +41,21 @@ impl PolarsDataset {
todo!()
}
Some(other_ext) => {
return Err(SchemaError::TimeseriesUnsupportedFileFormat {
return Err(TimeseriesError::TimeseriesUnsupportedFileFormat {
provider: "polars".to_string(),
fmt: other_ext.to_string(),
})
}
None => {
return Err(SchemaError::TimeseriesUnparsableFileFormat {
return Err(TimeseriesError::TimeseriesUnparsableFileFormat {
provider: "polars".to_string(),
path: self.url.to_string_lossy().to_string(),
})
}
}
}
None => {
return Err(SchemaError::TimeseriesUnparsableFileFormat {
return Err(TimeseriesError::TimeseriesUnparsableFileFormat {
provider: "polars".to_string(),
path: self.url.to_string_lossy().to_string(),
})
Expand Down

0 comments on commit 72fe90e

Please sign in to comment.