diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 43a5334..17c11b7 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -28,16 +28,11 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- - uses: conda-incubator/setup-miniconda@v2
+ - uses: mamba-org/setup-micromamba@v1
with:
activate-environment: test
environment-file: environment.yml
auto-activate-base: false
- - name: conda check
- shell: bash -l {0}
- run: |
- conda info
- conda list
- name: install hat package
shell: bash -l {0}
run: pip install .
diff --git a/.github/workflows/on-push.yml b/.github/workflows/on-push.yml
index c796d57..968e5fa 100644
--- a/.github/workflows/on-push.yml
+++ b/.github/workflows/on-push.yml
@@ -32,16 +32,11 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- - uses: conda-incubator/setup-miniconda@v2
+ - uses: mamba-org/setup-micromamba@v1
with:
activate-environment: test
environment-file: environment.yml
auto-activate-base: false
- - name: Conda check
- shell: bash -l {0}
- run: |
- conda info
- conda list
- name: install hat package
shell: bash -l {0}
run: pip install .
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 30af686..5b536b2 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -7,7 +7,7 @@ repos:
rev: 23.3.0
hooks:
- id: black
- - repo: https://github.com/PyCQA/flake8
- rev: 6.0.0
- hooks:
- - id: flake8
+# - repo: https://github.com/PyCQA/flake8
+# rev: 6.0.0
+# hooks:
+# - id: flake8
diff --git a/README.md b/README.md
index 0d825dc..69cc149 100644
--- a/README.md
+++ b/README.md
@@ -14,7 +14,7 @@ Interfaces and functionality are likely to change, and the project itself may be
Clone source code repository
- $ git clone git@github.com:ecmwf-projects/hat.git
+ $ git clone https://github.com/ecmwf/hat.git
Create conda python environment
@@ -63,4 +63,4 @@ does it submit to any jurisdiction.
### Citing
-In publications, please use a link to this repository (https://github.com/ecmwf/hat) and its documentation (https://hydro-analysis-toolkit.readthedocs.io)
\ No newline at end of file
+In publications, please use a link to this repository (https://github.com/ecmwf/hat) and its documentation (https://hydro-analysis-toolkit.readthedocs.io)
diff --git a/environment.yml b/environment.yml
index 6b28254..1c2e5c2 100644
--- a/environment.yml
+++ b/environment.yml
@@ -2,21 +2,21 @@ name: hat
channels:
- conda-forge
dependencies:
- - python=3.10
+ - python<=3.10
- netCDF4
- eccodes
- cfgrib
- - cftime
- geopandas
- xarray
- plotly
- matplotlib
- - jupyterlab
+ - jupyter
- tqdm
- typer
- humanize
- - folium
- typer
+ - ipyleaflet
+ - ipywidgets
- pip
# - pytest
# - mkdocs
diff --git a/hat/clock.py b/hat/clock.py
deleted file mode 100644
index 409fa97..0000000
--- a/hat/clock.py
+++ /dev/null
@@ -1,71 +0,0 @@
-import datetime
-import functools
-import time
-
-import humanize
-
-
-def digital_clock(func):
- "A quiet decorator for timing functions"
-
- @functools.wraps(func)
- def clocked(*args, **kwargs):
- # time
- t0 = time.perf_counter()
- result = func(*args, **kwargs)
- elapsed = time.perf_counter() - t0
- human_readable_time = humanize.naturaldelta(datetime.timedelta(seconds=elapsed))
-
- # name
- name = func.__name__
-
- print(f"{name}() took {human_readable_time}")
- return result
-
- # return function with timing decorator
- return clocked
-
-
-def clock(func):
- "A verbose decorator for timing functions"
-
- @functools.wraps(func)
- def clocked(*args, **kwargs):
- # time
- t0 = time.perf_counter()
- result = func(*args, **kwargs)
- elapsed = time.perf_counter() - t0
- human_readable_time = humanize.naturaldelta(datetime.timedelta(seconds=elapsed))
-
- # name
- name = func.__name__
-
- # arguments
- arg_str = ", ".join(repr(arg) for arg in args)
- if arg_str == "":
- arg_str = "no arguments"
-
- # keywords
- pairs = [f"{k}={w}" for k, w in sorted(kwargs.items())]
- key_str = ", ".join(pairs)
- if key_str == "":
- key_str = "no keywords"
-
- print(
- f"""{name}() took {human_readable_time} to run
- with following inputs {arg_str} and {key_str}"""
- )
- return result
-
- # return function with timing decorator
- return clocked
-
-
-if __name__ == "__main__":
-
- @clock
- def test(arg1, keyword=False):
- time.sleep(0.1)
- pass
-
- test(1, keyword="hello")
diff --git a/hat/data.py b/hat/data.py
index 4f25eee..d1ff6e0 100644
--- a/hat/data.py
+++ b/hat/data.py
@@ -37,6 +37,9 @@ def get_tmpdir():
else:
tmpdir = TemporaryDirectory().name
+ # Ensure the directory exists
+ os.makedirs(tmpdir, exist_ok=True)
+
return tmpdir
@@ -257,8 +260,8 @@ def save_dataset_to_netcdf(ds: xr.Dataset, fpath: str):
ds.to_netcdf(fpath)
-def find_main_var(ds):
- variable_names = [k for k in ds.variables if len(ds.variables[k].dims) >= 3]
+def find_main_var(ds, min_dim=3):
+ variable_names = [k for k in ds.variables if len(ds.variables[k].dims) >= min_dim]
if len(variable_names) > 1:
raise Exception("More than one variable in dataset")
elif len(variable_names) == 0:
diff --git a/hat/filters.py b/hat/filters.py
index 88cdb07..bc0aebb 100644
--- a/hat/filters.py
+++ b/hat/filters.py
@@ -144,7 +144,7 @@ def filter_dataframe(df, filters: str):
return df
-def filter_timeseries(sims_ds: xr.Dataset, obs_ds: xr.Dataset, threshold=80):
+def filter_timeseries(sims_ds: xr.DataArray, obs_ds: xr.DataArray, threshold=80):
"""Clean the simulation and observation timeseries
Only keep..
@@ -159,30 +159,36 @@ def filter_timeseries(sims_ds: xr.Dataset, obs_ds: xr.Dataset, threshold=80):
matching_stations = sorted(
set(sims_ds.station.values).intersection(obs_ds.station.values)
)
+ print(len(matching_stations))
sims_ds = sims_ds.sel(station=matching_stations)
obs_ds = obs_ds.sel(station=matching_stations)
+ obs_ds = obs_ds.sel(time=sims_ds.time)
+
+ obs_ds = obs_ds.dropna(dim="station", how="all")
+ sims_ds = sims_ds.sel(station=obs_ds.station)
# Only keep observations in the same time period as the simulations
- obs_ds = obs_ds.where(sims_ds.time == obs_ds.time, drop=True)
+ # obs_ds = obs_ds.where(sims_ds.time == obs_ds.time, drop=True)
# Only keep obsevations with enough valid data in this timeperiod
# discharge data
- dis = obs_ds.obsdis
+ print(sims_ds)
+ print(obs_ds)
# Replace negative values with NaN
- dis = dis.where(dis >= 0)
+ # dis = dis.where(dis >= 0)
- # Percentage of valid discharge data at each point in time
- valid_percent = dis.notnull().mean(dim="time") * 100
+ # # Percentage of valid discharge data at each point in time
+ # valid_percent = dis.notnull().mean(dim="time") * 100
- # Boolean index of where there is enough valid data
- enough_observation_data = valid_percent > threshold
+ # # Boolean index of where there is enough valid data
+ # enough_observation_data = valid_percent > threshold
- # keep where there is enough observation data
- obs_ds = obs_ds.where(enough_observation_data, drop=True)
+ # # keep where there is enough observation data
+ # obs_ds = obs_ds.where(enough_observation_data, drop=True)
- # keep simulation that match remaining observations
- sims_ds = sims_ds.where(enough_observation_data, drop=True)
+ # # keep simulation that match remaining observations
+ # sims_ds = sims_ds.where(enough_observation_data, drop=True)
return (sims_ds, obs_ds)
diff --git a/hat/graphs.py b/hat/graphs.py
deleted file mode 100644
index 8f6ba39..0000000
--- a/hat/graphs.py
+++ /dev/null
@@ -1,30 +0,0 @@
-import pandas as pd
-import plotly.express as px
-
-
-def graph_sims_and_obs(
- sims,
- obs,
- ID,
- sims_data_name="simulation_timeseries",
- obs_data_name="obsdis",
- height=500,
- width=1200,
-):
- # observations, simulations, time
- o = obs.sel(station=ID)[obs_data_name].values
- s = sims.sel(station=ID)[sims_data_name].values
- t = obs.sel(station=ID).time.values
-
- df = pd.DataFrame({"time": t, "simulations": s, "observations": o})
- fig = px.line(
- df,
- x="time",
- y=["simulations", "observations"],
- title="Simulations & Observations",
- )
- fig.data[0].line.color = "#34eb7d"
- fig.data[1].line.color = "#3495eb"
- fig.update_layout(height=height, width=width)
- fig.update_yaxes(title_text="discharge")
- fig.show()
diff --git a/hat/hydrostats.py b/hat/hydrostats.py
index d03c3fe..79238a9 100644
--- a/hat/hydrostats.py
+++ b/hat/hydrostats.py
@@ -1,30 +1,25 @@
"""high level python api for hydrological statistics"""
from typing import List
-import folium
-import geopandas as gpd
-import pandas as pd
+import numpy as np
import xarray as xr
-from branca.colormap import linear
-from folium.plugins import Fullscreen
from hat import hydrostats_functions
def run_analysis(
functions: List,
- sims_ds: xr.Dataset,
- obs_ds: xr.Dataset,
- sims_var_name="simulation_timeseries",
- obs_var_name="obsdis",
+ sims_ds: xr.DataArray,
+ obs_ds: xr.DataArray,
) -> xr.Dataset:
- """Run statistical analysis on simulation and observation timeseries"""
+ """
+ Run statistical analysis on simulation and observation timeseries
+ """
# list of stations
stations = sims_ds.coords["station"].values
- # Create an empty DataFrame with stations as the index
- df = pd.DataFrame(index=stations)
+ ds = xr.Dataset()
# For each statistical function
for name in functions:
@@ -33,95 +28,19 @@ def run_analysis(
# do timeseries analysis for each station
# (using a "numpy in, numpy out" function)
- statistics = {}
+ statistics = []
for station in stations:
- sims = sims_ds.sel(station=station)[sims_var_name].to_numpy()
- obs = obs_ds.sel(station=station)[obs_var_name].to_numpy()
- statistics[station] = func(sims, obs)
+ sims = sims_ds.sel(station=station).to_numpy()
+ obs = obs_ds.sel(station=station).to_numpy()
- # 1D Series of statistics (i.e. scalar value per station)
- statistics_series = pd.Series(statistics, name=name)
+ stat = func(sims, obs)
+ if stat is None:
+ print(f"Warning! All NaNs for station {station}")
+ stat = 0
+ statistics += [stat]
+ statistics = np.array(statistics)
# Add the Series to the DataFrame
- df[name] = statistics_series
-
- # Convert the DataFrame to an xarray Dataset
- ds = df.to_xarray()
- ds = ds.rename({"index": "station"})
- ds["longitude"] = ("station", sims_ds.coords["longitude"].data)
- ds["latitude"] = ("station", sims_ds.coords["latitude"].data)
+ ds[name] = xr.DataArray(statistics, coords={"station": stations})
return ds
-
-
-def display_map(ds: xr.Dataset, name: str, minv: float = 0, maxv: float = 1):
- # xarray to geopandas
- gdf = gpd.GeoDataFrame(
- ds.to_dataframe(),
- geometry=gpd.points_from_xy(ds["longitude"], ds["latitude"]),
- crs="epsg:4326",
- )
- gdf["station_id"] = gdf.index
- gdf = gdf[~gdf[name].isnull()]
-
- # Create a color map
- colormap = linear.Blues_09.scale(minv, maxv)
-
- # Define style function
- def style_function(feature):
- property = feature["properties"][name]
- return {"fillOpacity": 0.7, "weight": 0, "fillColor": colormap(property)}
-
- m = folium.Map(location=[48, 5], zoom_start=5, prefer_canvas=True, tiles=None)
- _ = folium.GeoJson(
- gdf,
- marker=folium.CircleMarker(fillColor="white", fillOpacity=0.5, radius=5),
- name=name,
- style_function=style_function,
- tooltip=folium.GeoJsonTooltip(fields=["station_id", name]),
- popup=folium.GeoJsonPopup(fields=["station_id", name]),
- ).add_to(m)
-
- # Add the CartoDB Positron tileset as a layer
- cartodb_positron = folium.TileLayer(
- tiles="CartoDB Dark_Matter",
- name="Dark",
- overlay=False,
- control=True,
- )
- cartodb_positron.add_to(m)
-
- # Add the CartoDB Positron tileset as a layer
- cartodb_positron = folium.TileLayer(
- tiles="CartoDB Positron",
- name="Light",
- overlay=False,
- control=True,
- )
- cartodb_positron.add_to(m)
-
- # Add OpenStreetMap layer
- open_street_map = folium.TileLayer(
- tiles="OpenStreetMap",
- name="Open Street Map",
- overlay=False,
- control=True,
- )
- open_street_map.add_to(m)
-
- # Add the satellite layer
- esri_satellite = folium.TileLayer(
- tiles="""https://server.arcgisonline.com/ArcGIS/rest/services/
- World_Imagery/MapServer/tile/{z}/{y}/{x}""",
- attr="Esri",
- name="Satellite",
- overlay=False,
- control=True,
- )
- esri_satellite.add_to(m)
-
- # add controls
- folium.LayerControl().add_to(m)
- Fullscreen().add_to(m)
-
- return m
diff --git a/hat/hydrostats_decorators.py b/hat/hydrostats_decorators.py
index cfe6817..e0a219a 100644
--- a/hat/hydrostats_decorators.py
+++ b/hat/hydrostats_decorators.py
@@ -49,7 +49,7 @@ def wrapper(arr1: np.ndarray, arr2: np.ndarray):
nan_mask2 = np.isnan(arr2)
filtered_mask = ~(nan_mask1 | nan_mask2)
if not np.any(filtered_mask):
- raise ValueError("All elements are NaN")
+ return None
return func(arr1[filtered_mask], arr2[filtered_mask])
return wrapper
diff --git a/hat/images.py b/hat/images.py
deleted file mode 100644
index 14ebeae..0000000
--- a/hat/images.py
+++ /dev/null
@@ -1,29 +0,0 @@
-import matplotlib
-import numpy as np
-from quicklook import quicklook
-
-
-def arr_to_image(arr: np.array) -> np.array:
- """modify array so that it is optimized for viewing"""
-
- # image array
- img = np.array(arr)
-
- img = quicklook.replace_nan(img)
- img = quicklook.percentile_clip(img, 2)
- img = quicklook.bytescale(img)
- img = quicklook.reshape_array(img)
-
- return img
-
-
-def numpy_to_png(
- arr: np.array, dim="time", index="somedate", fpath="image.png"
-) -> None:
- """Save numpy array to png"""
-
- # image from array
- img = arr_to_image(arr)
-
- # save to file
- matplotlib.image.imsave(fpath, img)
diff --git a/hat/interactive/__init__.py b/hat/interactive/__init__.py
new file mode 100644
index 0000000..99567ee
--- /dev/null
+++ b/hat/interactive/__init__.py
@@ -0,0 +1,12 @@
+# flake8: noqa
+from .explorers import TimeSeriesExplorer
+from .leaflet import LeafletMap, PyleafletColormap
+from .widgets import (
+ DataFrameWidget,
+ HTMLTableWidget,
+ MetaDataWidget,
+ PlotlyWidget,
+ StatisticsWidget,
+ Widget,
+ WidgetsManager,
+)
diff --git a/hat/interactive/explorers.py b/hat/interactive/explorers.py
new file mode 100644
index 0000000..0c88d4c
--- /dev/null
+++ b/hat/interactive/explorers.py
@@ -0,0 +1,358 @@
+import os
+
+import ipywidgets
+import pandas as pd
+import xarray as xr
+from IPython.display import display
+
+from hat.interactive.leaflet import LeafletMap, PyleafletColormap
+from hat.interactive.widgets import (
+ MetaDataWidget,
+ PlotlyWidget,
+ StatisticsWidget,
+ WidgetsManager,
+)
+from hat.observations import read_station_metadata_file
+
+
+def prepare_simulations_data(simulations, sims_var_name):
+ """
+ Process simulations and put then in a dictionnary of xarray data arrays.
+
+ Parameters
+ ----------
+ simulations : dict
+ A dictionary of paths to the simulation netCDF files, with the keys
+ being the simulation names.
+ sims_var_name : str
+ The name of the variable in the simulation netCDF files that contains
+ the simulated values.
+
+ Returns
+ -------
+ dict
+ A dictionary of xarray data arrays containing the simulation data.
+
+ """
+ # If simulations is a dictionary, load data for each experiment
+ sim_ds = {}
+ for exp, path in simulations.items():
+ # Expanding the tilde
+ expanded_path = os.path.expanduser(path)
+
+ if os.path.isfile(expanded_path): # Check if it's a file
+ ds = xr.open_dataset(expanded_path)
+
+ sim_ds[exp] = ds[sims_var_name]
+
+ return sim_ds
+
+
+def prepare_observations_data(observations, sim_ds, obs_var_name):
+ """
+ Process observation raw dataset to a standard xarray dataset.
+ The observation dataset can be either a csv file or a netcdf file.
+ The observation dataset is subsetted based on the time values of the
+ simulation dataset.
+
+ Parameters
+ ----------
+ observations : str
+ The path to the observation netCDF file.
+ sim_ds : dict or xarray.Dataset
+ A dictionary of xarray datasets containing the simulation data, or a
+ single xarray dataset.
+ obs_var_name : str
+ The name of the variable in the observation netCDF file that contains
+ the observed values.
+
+ Returns
+ -------
+ xarray.Dataset
+ An xarray dataset containing the observation data.
+
+ Raises
+ ------
+ ValueError
+ If the file format of the observations file is not supported.
+
+ """
+ file_extension = os.path.splitext(observations)[-1].lower()
+
+ if file_extension == ".csv":
+ obs_df = pd.read_csv(observations, parse_dates=["Timestamp"])
+ obs_melted = obs_df.melt(
+ id_vars="Timestamp", var_name="station", value_name=obs_var_name
+ )
+ # Convert the melted DataFrame to xarray Dataset
+ obs_ds = obs_melted.set_index(["Timestamp", "station"]).to_xarray()
+ obs_ds = obs_ds.rename({"Timestamp": "time"})
+ elif file_extension == ".nc":
+ obs_ds = xr.open_dataset(observations)
+ else:
+ raise ValueError("Unsupported file format for observations.")
+
+ # Subset obs_ds based on sim_ds time values
+ if isinstance(sim_ds, xr.Dataset):
+ time_values = sim_ds["time"].values
+ elif isinstance(sim_ds, dict):
+ # Use the first dataset in the dictionary to determine time values
+ first_dataset = next(iter(sim_ds.values()))
+ time_values = first_dataset["time"].values
+ else:
+ raise ValueError("Unexpected type for sim_ds")
+
+ obs_ds = obs_ds[obs_var_name].sel(time=time_values)
+ return obs_ds
+
+
+def find_common_stations(station_index, stations_metadata, obs_ds, sim_ds, statistics):
+ """
+ Find common stations between observations, simulations and station
+ metadata.
+
+ Parameters
+ ----------
+ station_index : str
+ The name of the column in the station metadata file that contains the
+ station IDs.
+ stations_metadata : pandas.DataFrame
+ A pandas DataFrame containing the station metadata.
+ obs_ds : xarray.Dataset
+ An xarray dataset containing the observation data.
+ sim_ds : dict or xarray.Dataset
+ A dictionary of xarray data arrays containing the simulation data.
+ statistics : dict
+ A dictionary of xarray data arrays containing the statistics data.
+
+ Returns
+ -------
+ list
+ A list of common station IDs.
+
+ """
+ ids = []
+ ids += [list(obs_ds["station"].values)]
+ ids += [list(ds["station"].values) for ds in sim_ds.values()]
+ ids += [stations_metadata[station_index]]
+ if statistics:
+ ids += [list(ds["station"].values) for ds in statistics.values()]
+
+ common_ids = None
+ for id in ids:
+ if common_ids is None:
+ common_ids = set(id)
+ else:
+ common_ids = set(id) & common_ids
+ return list(common_ids)
+
+
+class TimeSeriesExplorer:
+ """
+ Initialize the interactive map with configurations and data sources.
+ """
+
+ def __init__(self, config):
+ """
+ Initializes an instance of the Explorer class.
+
+ Parameters
+ ----------
+ config : dict
+ A dictionary containing the configuration parameters for the
+ Explorer.
+
+ Notes
+ -----
+ This method initializes an instance of the Explorer class with the
+ given configuration parameters.
+ The configuration parameters should be provided as a dictionary with
+ the following keys:
+
+ - stations : str
+ The path to the station metadata file.
+ - observations : str
+ The path to the observation netCDF file.
+ - simulations : dict
+ A dictionary of paths to the simulation netCDF files, with the keys
+ being the simulation names.
+ - statistics : dict, optional
+ A dictionary of paths to the statistics netCDF files, with the keys
+ being the simulation names.
+ - station_coordinates : list of str
+ The names of the columns in the station metadata file that contain
+ the station coordinates.
+ - station_epsg : int
+ The EPSG code of the coordinate reference system used by the
+ station coordinates.
+ - station_filters : dict
+ A dictionary of filters to apply to the station metadata file.
+ - sims_var_name : str
+ The name of the variable in the simulation netCDF files that
+ contains the simulated values.
+ - obs_var_name : str
+ The name of the variable in the observation netCDF file that
+ contains the observed values.
+ - station_id_column_name : str
+ The name of the column in the station metadata file that contains
+ the station IDs.
+
+ Raises
+ ------
+ AssertionError
+ If there is a mismatch between the keys of the statistics netCDF
+ files and the simulation netCDF files.
+
+ """
+ self.config = config
+
+ self.stations_metadata = read_station_metadata_file(
+ fpath=config["stations"],
+ coord_names=config["station_coordinates"],
+ epsg=config["station_epsg"],
+ filters=config["station_filters"],
+ )
+
+ # Use the external functions to prepare data
+ sim_ds = prepare_simulations_data(
+ config["simulations"], config["sims_var_name"]
+ )
+ obs_ds = prepare_observations_data(
+ config["observations"], sim_ds, config["obs_var_name"]
+ )
+
+ # set station index
+ self.station_index = config["station_id_column_name"]
+
+ # Retrieve statistics from the statistics netcdf input
+ self.statistics = {}
+ stats = config.get("statistics")
+ if stats is not None:
+ for name, path in stats.items():
+ self.statistics[name] = xr.open_dataset(path)
+
+ # Ensure the keys of self.statistics match the keys of self.sim_ds
+ assert set(self.statistics.keys()) == set(
+ sim_ds.keys()
+ ), "Mismatch between statistics and simulations keys."
+
+ # find common station ids between metadata, observation and simulations
+ common_ids = find_common_stations(
+ self.station_index,
+ self.stations_metadata,
+ obs_ds,
+ sim_ds,
+ self.statistics,
+ )
+
+ print(f"Found {len(common_ids)} common stations")
+ self.stations_metadata = self.stations_metadata.loc[
+ self.stations_metadata[self.station_index].isin(common_ids)
+ ]
+ obs_ds = obs_ds.sel(station=common_ids)
+ for sim, ds in sim_ds.items():
+ sim_ds[sim] = ds.sel(station=common_ids)
+
+ # Create loading widget
+ self.loading_widget = ipywidgets.Label(value="")
+
+ # Title label
+ self.title_label = ipywidgets.Label(
+ "Interactive Map Visualisation for Hydrological Model Performance",
+ layout=ipywidgets.Layout(justify_content="center"),
+ style={"font_weight": "bold", "font_size": "24px", "font_family": "Arial"},
+ )
+
+ # Create the interactive widgets
+ datasets = sim_ds
+ datasets["obs"] = obs_ds
+ widgets = {}
+ widgets["plot"] = PlotlyWidget(datasets)
+ widgets["stats"] = StatisticsWidget(self.statistics)
+ widgets["meta"] = MetaDataWidget(self.stations_metadata, self.station_index)
+ self.widgets = WidgetsManager(
+ widgets, config["station_id_column_name"], self.loading_widget
+ )
+
+ # Create the main leaflet map
+ self.leafletmap = LeafletMap()
+
+ def create_frame(self):
+ """
+ Initialize the layout of the widgets for the map visualization.
+
+ Returns
+ -------
+ ipywidgets.VBox
+ A vertical box containing the layout elements for the map
+ visualization.
+
+ """
+ # Layouts 2
+ main_layout = ipywidgets.Layout(
+ justify_content="space-around",
+ align_items="stretch",
+ spacing="2px",
+ width="1000px",
+ )
+ left_layout = ipywidgets.Layout(
+ justify_content="space-around",
+ align_items="center",
+ spacing="2px",
+ width="40%",
+ )
+ right_layout = ipywidgets.Layout(
+ justify_content="center", align_items="center", spacing="2px", width="60%"
+ )
+
+ # Frames
+ top_left_frame = self.leafletmap.output(left_layout)
+ top_right_frame = ipywidgets.VBox(
+ [self.widgets["plot"].output, self.widgets["stats"].output],
+ layout=right_layout,
+ )
+ main_top_frame = ipywidgets.HBox([top_left_frame, top_right_frame])
+
+ # Main layout
+ main_frame = ipywidgets.VBox(
+ [self.title_label, main_top_frame, self.widgets["meta"].output],
+ layout=main_layout,
+ )
+ return main_frame
+
+ def plot(self, colorby=None, sim=None, limits=None, mp_colormap="viridis"):
+ """
+ Plot the stations markers colored by a given metric.
+
+ Parameters
+ ----------
+ colorby : str, optional
+ The name of the metric to color the stations by.
+ sim : str, optional
+ The name of the simulation to use for the metric.
+ limits : list, optional
+ A list of two values representing the minimum and maximum values
+ for the color bar.
+ mp_colormap : str, optional
+ The name of the matplotlib colormap to use for the color bar.
+
+ """
+ # create colormap from statistics
+ stats = None
+ if self.statistics and colorby is not None and sim is not None:
+ stats = self.statistics[sim][colorby]
+ colormap = PyleafletColormap(self.config, stats, mp_colormap, limits)
+
+ # add layer to the leaflet map
+ self.leafletmap.add_geolayer(
+ self.stations_metadata,
+ colormap,
+ self.widgets,
+ self.config["station_coordinates"],
+ )
+
+ # Initialize frame elements
+ frame = self.create_frame()
+
+ # Display the main layout
+ display(frame)
diff --git a/hat/interactive/leaflet.py b/hat/interactive/leaflet.py
new file mode 100644
index 0000000..10ff223
--- /dev/null
+++ b/hat/interactive/leaflet.py
@@ -0,0 +1,248 @@
+import json
+
+import ipyleaflet
+import ipywidgets
+import matplotlib as mpl
+import matplotlib.pyplot as plt
+import numpy as np
+
+
+def _compute_bounds(stations_metadata, coord_names):
+ """Compute the bounds of the map based on the stations metadata."""
+
+ lon_column = coord_names[0]
+ lat_column = coord_names[1]
+
+ lons = stations_metadata[lon_column].values
+ lats = stations_metadata[lat_column].values
+
+ min_lat, max_lat = min(lats), max(lats)
+ min_lon, max_lon = min(lons), max(lons)
+
+ return [(float(min_lat), float(min_lon)), (float(max_lat), float(max_lon))]
+
+
+class LeafletMap:
+ """
+ A class for creating interactive leaflet maps.
+
+ Parameters
+ ----------
+ basemap : ipyleaflet.basemaps, optional
+ The basemap to use for the map. Default is
+ ipyleaflet.basemaps.OpenStreetMap.Mapnik.
+
+ """
+
+ def __init__(
+ self,
+ basemap=ipyleaflet.basemaps.OpenStreetMap.Mapnik,
+ ):
+ self.map = ipyleaflet.Map(
+ basemap=basemap, layout=ipywidgets.Layout(width="100%", height="600px")
+ )
+ self.legend_widget = ipywidgets.Output()
+
+ def _set_boundaries(self, stations_metadata, coord_names):
+ """
+ Compute the boundaries of the map based on the stations metadata.
+ """
+ lon_column = coord_names[0]
+ lat_column = coord_names[1]
+
+ lons = stations_metadata[lon_column].values
+ lats = stations_metadata[lat_column].values
+
+ min_lat, max_lat = min(lats), max(lats)
+ min_lon, max_lon = min(lons), max(lons)
+
+ bounds = [(float(min_lat), float(min_lon)), (float(max_lat), float(max_lon))]
+ self.map.fit_bounds(bounds)
+
+ def add_geolayer(self, geodata, colormap, widgets, coord_names=None):
+ """
+ Add a geolayer to the map.
+
+ Parameters
+ ----------
+ geodata : geopandas.GeoDataFrame
+ The geodataframe containing the geospatial data.
+ colormap : hat.PyleafletColormap
+ The colormap to use for the geolayer.
+ widgets : hat.WidgetsManager
+ The widgets to use for the geolayer.
+ coord_names : list of str, optional
+ The names of the columns containing the spatial coordinates.
+ Default is None.
+ """
+ geojson = ipyleaflet.GeoJSON(
+ data=json.loads(geodata.to_json()),
+ style={
+ "radius": 7,
+ "opacity": 0.5,
+ "weight": 1.9,
+ "dashArray": "2",
+ "fillOpacity": 0.5,
+ },
+ hover_style={"radius": 10, "fillOpacity": 1},
+ point_style={"radius": 5},
+ style_callback=colormap.style_callback(),
+ )
+ geojson.on_click(widgets.update)
+ self.map.add(geojson)
+
+ if coord_names is not None:
+ self._set_boundaries(geodata, coord_names)
+
+ self.legend_widget = colormap.legend()
+
+ def output(self, layout={}):
+ """
+ Return the output widget.
+
+ Parameters
+ ----------
+ layout : ipywidgets.Layout
+ The layout of the widget.
+
+ Returns
+ -------
+ ipywidgets.VBox
+ The output widget.
+
+ """
+ output = ipywidgets.VBox([self.map, self.legend_widget], layout=layout)
+ return output
+
+
+class PyleafletColormap:
+ """
+ A class handling the colormap of a pyleaflet map.
+
+ Parameters
+ ----------
+ config : dict
+ A dictionary containing configuration options for the map.
+ stats : xarray.Dataset or None, optional
+ A dataset containing the data to be plotted on the map.
+ If None, a default constant colormap will be used.
+ colormap_style : str, optional
+ The name of the matplotlib colormap to use. Default is 'viridis'.
+ range : tuple of float, optional
+ The minimum and maximum values of the colormap. If None, the
+ minimum and maximum values in `stats` will be used.
+ """
+
+ def __init__(
+ self,
+ config={},
+ stats=None,
+ colormap_style="viridis",
+ range=None,
+ empty_color="white",
+ default_color="blue",
+ ):
+ self.config = config
+ self.stats = stats
+ self.empty_color = empty_color
+ self.default_color = default_color
+ if self.stats is not None:
+ assert (
+ "station_id_column_name" in self.config
+ ), 'Config must contain "station_id_column_name"'
+ # Normalize the data for coloring
+ if range is None:
+ self.min_val = self.stats.values.min()
+ self.max_val = self.stats.values.max()
+ else:
+ self.min_val = range[0]
+ self.max_val = range[1]
+ else:
+ self.min_val = 0
+ self.max_val = 1
+
+ try:
+ self.colormap = mpl.colormaps[colormap_style]
+ except KeyError:
+ raise KeyError(
+ f"Colormap {colormap_style} not found. "
+ f"Available colormaps are: {mpl.colormaps}"
+ )
+
+ def style_callback(self):
+ """
+ Returns a function that can be used as input style for the ipyleaflet
+ layer.
+
+ Returns
+ -------
+ function
+ A function that takes a dataframe feature as input and returns a
+ dictionary of style options for the ipyleaflet layer.
+ """
+ if self.stats is not None:
+ norm = plt.Normalize(self.min_val, self.max_val)
+
+ def map_color(feature):
+ station_id = feature["properties"][
+ self.config["station_id_column_name"]
+ ]
+ if station_id in self.stats.station.values:
+ station_stats = self.stats.sel(station=station_id)
+ color = mpl.colors.rgb2hex(
+ self.colormap(norm(station_stats.values))
+ )
+ else:
+ color = self.empty_color
+ style = {
+ "color": "black",
+ "fillColor": color,
+ }
+ return style
+
+ else:
+
+ def map_color(feature):
+ return {
+ "color": "black",
+ "fillColor": self.default_color,
+ }
+
+ return map_color
+
+ def legend(self):
+ """
+ Generates an HTML legend for the colormap.
+
+ Returns
+ -------
+ ipywidgets.HTML
+ An HTML widget containing the colormap legend.
+ """
+ # Convert the colormap to a list of RGB values
+ rgb_values = [
+ mpl.colors.rgb2hex(self.colormap(i)) for i in np.linspace(0, 1, 256)
+ ]
+
+ # Create a gradient style using the RGB values
+ gradient_style = ", ".join(rgb_values)
+ gradient_html = f"""
+
+ """
+
+ # Create labels
+ labels_html = f"""
+
+ Low: {self.min_val:.1f}
+ High: {self.max_val:.1f}
+
+ """
+ # Combine gradient and labels
+ legend_html = gradient_html + labels_html
+
+ return ipywidgets.HTML(legend_html)
diff --git a/hat/interactive/widgets.py b/hat/interactive/widgets.py
new file mode 100644
index 0000000..ba7c5d0
--- /dev/null
+++ b/hat/interactive/widgets.py
@@ -0,0 +1,495 @@
+import time
+
+import numpy as np
+import pandas as pd
+import plotly.graph_objs as go
+from IPython.display import clear_output, display
+from ipywidgets import HTML, DatePicker, HBox, Label, Layout, Output, VBox
+
+
+class ThrottledClick:
+ """
+ Initialize a click throttler with a given delay.
+
+ Parameters
+ ----------
+ delay : float, optional
+ The delay in seconds between clicks. Defaults to 1.0.
+
+ Notes
+ -----
+ This class is used to prevent users from rapidly clicking a button or widget
+ multiple times, which can cause the application to crash or behave unexpectedly.
+
+ Examples
+ --------
+ >>> click_throttler = ThrottledClick(delay=0.5)
+ >>> if click_throttler.should_process():
+ ... # do something
+ """
+
+ def __init__(self, delay=1.0):
+ self.delay = delay
+ self.last_call = 0
+
+ def should_process(self):
+ """
+ Determine if a click should be processed based on the delay.
+
+ Returns
+ -------
+ bool
+ True if the click should be processed, False otherwise.
+
+ Notes
+ -----
+ This method should be called before processing a click event. If the
+ time since the last click is greater than the delay, the method returns
+ True and updates the last_call attribute. Otherwise, it returns False.
+ """
+ current_time = time.time()
+ if current_time - self.last_call > self.delay:
+ self.last_call = current_time
+ return True
+ return False
+
+
+class WidgetsManager:
+ """
+ A class for managing a collection of widgets and updating following
+ a user interaction, providing an index.
+
+ Parameters
+ ----------
+ widgets : dict
+ A dictionary of widgets to manage.
+ index_column : str
+ The name of the column containing the index used to update the widgets.
+ loading_widget : optional
+ A widget to display a loading message while data is being loaded.
+
+ Attributes
+ ----------
+ widgets : dict
+ A dictionary of widgets being managed.
+ index_column : str
+ The name of the column containing the index used to update the widgets.
+ throttler : ThrottledClick
+ A throttler for click events.
+ loading_widget : optional
+ A widget to display a loading message while data is being loaded.
+ """
+
+ def __init__(self, widgets, index_column, loading_widget=None):
+ self.widgets = widgets
+ self.index_column = index_column
+ self.throttler = ThrottledClick()
+ self.loading_widget = loading_widget
+
+ def __getitem__(self, item):
+ return self.widgets[item]
+
+ def update(self, feature, **kwargs):
+ """
+ Handle the selection of a marker on the map.
+
+ Parameters
+ ----------
+ feature : dict
+ A dictionary containing information about the selected feature.
+ **kwargs : dict
+ Additional keyword arguments to pass to the widgets update method.
+ """
+
+ # Check if we should process the click
+ if not self.throttler.should_process():
+ return
+
+ if self.loading_widget is not None:
+ self.loading_widget.value = (
+ "Loading..." # Indicate that data is being loaded
+ )
+
+ # Extract station_id from the selected feature
+ metadata = feature["properties"]
+ index = metadata[self.index_column]
+
+ # update widgets
+ for wgt in self.widgets.values():
+ wgt.update(index, metadata, **kwargs)
+
+ if self.loading_widget is not None:
+ self.loading_widget.value = "" # Clear the loading message
+
+
+class Widget:
+ """
+ A base class for interactive widgets.
+
+ Parameters
+ ----------
+ output : Output
+ The ipywidget compatible object to display the widget's content.
+
+ Attributes
+ ----------
+ output : Output
+ The ipywidget compatible object to display the widget's content.
+
+ Methods
+ -------
+ update(index, metadata, **kwargs)
+ Update the widget's content based on the given index and metadata.
+ """
+
+ def __init__(self, output):
+ self.output = output
+
+ def update(self, index, *args, **kwargs):
+ raise NotImplementedError
+
+
+def _filter_nan_values(dates, data_values):
+ """
+ Filters out NaN values and their associated dates.
+ """
+ assert len(dates) == len(
+ data_values
+ ), "Dates and data values must be the same length."
+ valid_dates = [date for date, val in zip(dates, data_values) if not np.isnan(val)]
+ valid_data = [val for val in data_values if not np.isnan(val)]
+
+ return valid_dates, valid_data
+
+
+class PlotlyWidget(Widget):
+ """
+ A widget to display timeseries data using Plotly.
+
+ Parameters
+ ----------
+ datasets : dict
+ A dictionary containing the xarray timeseries datasets to be displayed.
+
+ Attributes
+ ----------
+ datasets : dict
+ A dictionary containing the xarray timeseries datasets to be displayed.
+ figure : plotly.graph_objs._figurewidget.FigureWidget
+ The Plotly figure widget.
+ ds_time_str : list
+ A list of strings representing the dates in the timeseries data.
+ start_date_picker : DatePicker
+ The date picker widget for selecting the start date.
+ end_date_picker : DatePicker
+ The date picker widget for selecting the end date.
+ """
+
+ def __init__(self, datasets):
+ self.datasets = datasets
+ self.figure = go.FigureWidget(
+ layout=go.Layout(
+ height=350,
+ margin=dict(l=120),
+ legend=dict(
+ orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1
+ ),
+ xaxis_title="Date",
+ xaxis_tickformat="%d-%m-%Y",
+ yaxis_title="Discharge [m3/s]",
+ )
+ )
+ ds_time = datasets["obs"]["time"].values.astype("datetime64[D]")
+ self.ds_time_str = [dt.isoformat() for dt in pd.to_datetime(ds_time)]
+
+ self.start_date_picker = DatePicker(description="Start")
+ self.end_date_picker = DatePicker(description="End")
+
+ self.start_date_picker.observe(self._update_plot_dates, names="value")
+ self.end_date_picker.observe(self._update_plot_dates, names="value")
+
+ date_label = Label(
+ "Please select the date to accurately change the date axis of the plot"
+ )
+ date_picker_box = HBox([self.start_date_picker, self.end_date_picker])
+
+ layout = Layout(justify_content="center", align_items="center")
+ output = VBox([self.figure, date_label, date_picker_box], layout=layout)
+ super().__init__(output)
+
+ def _update_plot_dates(self):
+ """
+ Updates the plot with the selected start and end dates.
+ """
+ start_date = self.start_date_picker.value.strftime("%Y-%m-%d")
+ end_date = self.end_date_picker.value.strftime("%Y-%m-%d")
+ self.figure.update_layout(xaxis_range=[start_date, end_date])
+
+ def _update_data(self, station_id):
+ """
+ Updates the simulation data for the given station ID.
+ """
+ for name, ds in self.datasets.items():
+ if station_id in ds["station"].values:
+ ds_time_series_data = ds.sel(station=station_id).values
+ valid_dates_ds, valid_data_ds = _filter_nan_values(
+ self.ds_time_str, ds_time_series_data
+ )
+ self._update_trace(valid_dates_ds, valid_data_ds, name)
+ else:
+ print(f"Station ID: {station_id} not found in dataset {name}.")
+ return False
+ return True
+
+ def _update_trace(self, x_data, y_data, name):
+ """
+ Updates the plot trace for the given name with the given x and y data.
+ """
+ trace_exists = any([trace.name == name for trace in self.figure.data])
+ if trace_exists:
+ for trace in self.figure.data:
+ if trace.name == name:
+ trace.x = x_data
+ trace.y = y_data
+ else:
+ self.figure.add_trace(
+ go.Scatter(x=x_data, y=y_data, mode="lines", name=name)
+ )
+
+ def _update_title(self, metadata):
+ """
+ Updates the plot title following the point metadata.
+ """
+ station_id = metadata["station_id"]
+ station_name = metadata["StationName"]
+ updated_title = (
+ f"Selected station:
ID: {station_id}, name: {station_name} "
+ )
+ self.figure.update_layout(
+ title={
+ "text": updated_title,
+ "y": 0.9,
+ "x": 0.5,
+ "xanchor": "center",
+ "yanchor": "top",
+ "font": {"color": "black", "size": 16},
+ }
+ )
+
+ def update(self, index, *args, **kwargs):
+ """
+ Updates the overall plot with new data for the given index.
+
+ Parameters
+ ----------
+ index : str
+ The ID of the station to update the data for.
+ metadata : dict
+ A dictionary containing the metadata for the selected station.
+ """
+ return self._update_data(index)
+
+
+class HTMLTableWidget(Widget):
+ """
+ A widget to display a pandas dataframe with the HTML format.
+
+ Parameters
+ ----------
+ title : str
+ The title of the table.
+ """
+
+ def __init__(self, title):
+ self.title = title
+ super().__init__(Output())
+
+ # Define the styles for the statistics table
+ self.table_style = """
+
+ """
+ self.stat_title_style = (
+ "style='font-size: 18px; font-weight: bold; text-align: center;'"
+ )
+ # Initialize the stat_table_html and station_table_html with empty tables
+ empty_df = pd.DataFrame()
+ self._display_dataframe_with_scroll(empty_df, title=self.title)
+
+ def _display_dataframe_with_scroll(self, df, title=""):
+ table_html = df.to_html(classes="custom-table")
+ content = f"{self.table_style}{title}
{table_html}" # noqa: E501
+ with self.output:
+ clear_output(wait=True) # Clear any previous plots or messages
+ display(HTML(content))
+
+ def update(self, index, *args, **kwargs):
+ """
+ Update the table with the dataframe as the given index.
+
+ Parameters
+ ----------
+ index : int
+ The index of the data to be displayed.
+ metadata : dict
+ The metadata associated with the data index.
+ """
+ dataframe = self._extract_dataframe(index)
+ self._display_dataframe_with_scroll(dataframe, title=self.title)
+ if dataframe.empty:
+ return False
+ return True
+
+
+class DataFrameWidget(Widget):
+ """
+ A widget to display a pandas dataframe with the default pandas display
+ style.
+
+ Parameters
+ ----------
+ title : str
+ The title of the table.
+ """
+
+ def __init__(self, title):
+ self.title = title
+ super().__init__(output=Output(title=self.title))
+
+ # Initialize the stat_table_html and station_table_html with empty tables
+ empty_df = pd.DataFrame()
+ with self.output:
+ clear_output(wait=True) # Clear any previous plots or messages
+ display(empty_df)
+
+ def update(self, index, *args, **kwargs):
+ """
+ Update the table with the dataframe as the given index.
+
+ Parameters
+ ----------
+ index : int
+ The index of the data to be displayed.
+ metadata : dict
+ The metadata associated with the data index.
+ """
+ dataframe = self._extract_dataframe(index)
+ with self.output:
+ clear_output(wait=True) # Clear any previous plots or messages
+ display(dataframe)
+ if dataframe.empty:
+ return False
+ return True
+
+ def _extract_dataframe(self, index):
+ """
+ Virtual method to return the object dataframe at the index.
+ """
+ raise NotImplementedError
+
+
+class MetaDataWidget(HTMLTableWidget):
+ """
+ An extension of the HTMLTableWidget class to display a station metadata.
+
+ Parameters
+ ----------
+ dataframe : pd.DataFrame
+ A pandas dataframe to be displayed in the table.
+ station_index : str
+ Column name of the station index.
+ """
+
+ def __init__(self, dataframe, station_index):
+ title = "Station Metadata"
+ self.dataframe = dataframe
+ self.station_index = station_index
+ super().__init__(title)
+
+ def _extract_dataframe(self, station_id):
+ stations_df = self.dataframe
+ selected_station_df = stations_df[stations_df[self.station_index] == station_id]
+ return selected_station_df
+
+
+class StatisticsWidget(HTMLTableWidget):
+ """
+ An extension of the HTMLTableWidget to display statistics at stations.
+
+ Parameters
+ ----------
+ dataframe : pd.DataFrame
+ A pandas dataframe to be displayed in the table.
+ station_index : str
+ Column name of the station index.
+ """
+
+ def __init__(self, statistics):
+ title = "Model Performance Statistics Overview"
+ self.statistics = statistics
+ for stat in self.statistics.values():
+ assert (
+ "station" in stat.dims
+ ), 'Dimension "station" not found in statistics datasets.' # noqa: E501
+ super().__init__(title)
+
+ def _extract_dataframe(self, station_id):
+ """Generate a statistics table for the given station ID."""
+ data = []
+
+ # Check if statistics is None or empty
+ if not self.statistics:
+ print("No statistics data provided.")
+ return pd.DataFrame() # Return an empty dataframe
+
+ # Loop through each simulation and get the statistics for the given station_id
+ for exp_name, stats in self.statistics.items():
+ if station_id in stats["station"].values:
+ row = [exp_name] + [
+ round(stats[var].sel(station=station_id).values.item(), 2)
+ for var in stats.data_vars
+ if var not in ["longitude", "latitude"]
+ ]
+ data.append(row)
+
+ # Check if data has any items
+ if not data:
+ print(f"No statistics data found for station ID: {station_id}.")
+ return pd.DataFrame() # Return an empty dataframe
+
+ # Convert the data to a DataFrame for display
+ columns = ["Exp. name"] + list(stats.data_vars.keys())
+ statistics_df = pd.DataFrame(data, columns=columns)
+
+ # Round the numerical columns to 2 decimal places
+ numerical_columns = [col for col in statistics_df.columns if col != "Exp. name"]
+ statistics_df[numerical_columns] = statistics_df[numerical_columns].round(2)
+
+ return statistics_df
diff --git a/hat/networking.py b/hat/networking.py
deleted file mode 100644
index a5b04ab..0000000
--- a/hat/networking.py
+++ /dev/null
@@ -1,44 +0,0 @@
-import os
-import platform
-import socket
-
-
-def get_host():
- """Local host on Mac and network host on HPC
- (note the network address on HPC is not constant)"""
-
- # get the hostname
- hostname = socket.gethostname()
-
- # get the IP address(es) associated with the hostname
- ip_addresses = socket.getaddrinfo(
- hostname, None, socket.AF_INET, socket.SOCK_STREAM
- )
-
- # return first valid address
- for ip_address in ip_addresses:
- network_host = ip_address[4][0]
- return network_host
-
-
-def mac_or_hpc():
- """Is this running on a Mac or the HPC or other?"""
-
- if platform.system() == "Darwin":
- return "mac"
- elif platform.system() == "Linux" and os.environ.get("ECPLATFORM"):
- return "hpc"
- else:
- return "other"
-
-
-def host_and_port(host="127.0.0.1", port=8000):
- """return network host and port for tiler app to use"""
-
- computer = mac_or_hpc()
-
- if computer == "hpc":
- host = get_host()
- port = 8700
-
- return (host, port)
diff --git a/hat/observations.py b/hat/observations.py
index f895201..467eb70 100644
--- a/hat/observations.py
+++ b/hat/observations.py
@@ -37,9 +37,6 @@ def read_station_metadata_file(
"""read hydrological stations from file. will cache as pickle object
because .csv file used by the team takes 12 seconds to load"""
- print("station file")
- print(fpath)
-
try:
if is_csv(fpath):
gdf = read_csv_and_cache(fpath)
diff --git a/hat/parsers.py b/hat/parsers.py
deleted file mode 100644
index 0c50b64..0000000
--- a/hat/parsers.py
+++ /dev/null
@@ -1,17 +0,0 @@
-import dateutil.parser
-import streamlit as st
-
-
-@st.cache_data
-def datetime_from_cftime(cftimes):
- """parse CFTimeIndex to python datetime,
- e.g. from a NetCDF file ds.indexes['time']"""
- return [dateutil.parser.parse(x.isoformat()) for x in cftimes]
-
-
-def simulation_timeperiod(sim):
- # simulation timeperiod
- min_time = min(sim.indexes["time"])
- max_time = max(sim.indexes["time"])
-
- return (min_time, max_time)
diff --git a/hat/plots.py b/hat/plots.py
deleted file mode 100644
index 3581409..0000000
--- a/hat/plots.py
+++ /dev/null
@@ -1,72 +0,0 @@
-from typing import Union
-
-import numpy as np
-import pandas as pd
-import plotly.express as px
-import streamlit as st
-from matplotlib import pyplot as plt
-
-
-# PLOTLY (interactive)
-def plotly_timeseries(t, y):
- df = pd.DataFrame({"time": t, "discharge": y})
- return px.line(df, x="time", y="discharge", title="Discharge Timeseries")
-
-
-# MATPLOTLIB (not interactive)
-def plot_timeseries(t, y, jupyter=False):
- fig, ax1 = plt.subplots()
- fig.set_size_inches(14, 6)
- ax1.plot(t, y, "dodgerblue")
- ax1.set_xlabel("time (s)")
- ax1.set_ylabel("discharge", color="b")
- ax1.tick_params("y", colors="b")
-
- if jupyter:
- return fig
-
- st.write(fig)
-
-
-def histogram(
- arr: Union[np.array, np.ma.MaskedArray],
- bins=10,
- clip=None,
- title="Histogram",
- figsize=(6, 4),
-):
- """plot histogram of a numpy array or masked numpy array"""
-
- # apply mask (if one exists)
- if isinstance(arr, np.ma.MaskedArray):
- arr = arr.compressed()
-
- # return if not numpy
- if not isinstance(arr, np.ndarray):
- print("histogram() requires a numpy array or masked numpy array")
- return
-
- # remove flat dimensions
- arr = arr.squeeze()
-
- # remove nans
- arr = arr[~np.isnan(arr)]
-
- # histogram range (percentile clip or minmax)
- if clip:
- histogram_range = (
- round(np.percentile(arr, clip)),
- round(np.percentile(arr, 100 - clip)),
- )
- else:
- histogram_range = (np.min(arr), np.max(arr))
-
- # count number of values in each bin
- counts, bins = np.histogram(arr, bins=bins, range=histogram_range)
-
- _ = plt.figure(figsize=figsize)
- plt.hist(bins[:-1], bins, weights=counts)
- plt.title(title)
-
- # show plot
- plt.show()
diff --git a/hat/tools/hydrostats_cli.py b/hat/tools/hydrostats_cli.py
index 5019b9c..204f92a 100644
--- a/hat/tools/hydrostats_cli.py
+++ b/hat/tools/hydrostats_cli.py
@@ -4,6 +4,7 @@
import xarray as xr
from hat import hydrostats_functions
+from hat.data import find_main_var
from hat.exceptions import UserError
from hat.filters import filter_timeseries
from hat.hydrostats import run_analysis
@@ -81,15 +82,19 @@ def hydrostats_cli(
# simulations
sims_ds = xr.open_dataset(sims)
+ var = find_main_var(sims_ds, min_dim=2)
+ sims_da = sims_ds[var]
# observations
obs_ds = xr.open_dataset(obs)
+ var = find_main_var(obs_ds, min_dim=2)
+ obs_da = obs_ds[var]
# clean timeseries
- sims_ds, obs = filter_timeseries(sims_ds, obs_ds, threshold=obs_threshold)
+ sims_da, obs_da = filter_timeseries(sims_da, obs_da, threshold=obs_threshold)
# calculate statistics
- statistics_ds = run_analysis(functions, sims_ds, obs_ds)
+ statistics_ds = run_analysis(functions, sims_da, obs_da)
# save to netcdf
statistics_ds.to_netcdf(outpath)
diff --git a/hat/version.py b/hat/version.py
index 3d18726..906d362 100644
--- a/hat/version.py
+++ b/hat/version.py
@@ -1 +1 @@
-__version__ = "0.5.0"
+__version__ = "0.6.0"
diff --git a/notebooks/examples/4_visualisation_interactive.ipynb b/notebooks/examples/4_visualisation_interactive.ipynb
new file mode 100644
index 0000000..2194b6d
--- /dev/null
+++ b/notebooks/examples/4_visualisation_interactive.ipynb
@@ -0,0 +1,131 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Import necessart visualisation library and define inputs simulation and additional vector to add"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "config = {\n",
+ " \"stations\": \"~/git/hat/data/outlets_v4.0_20230726_withEFAS.csv\",\n",
+ " \"observations\": \"~/git/hat/data/observations/destine_observations.nc\",\n",
+ " \"simulations\": {\n",
+ " \"i05j\": \"~/git/hat/data/cama_i05j_stations.nc\",\n",
+ " \"i05h\": \"~/git/hat/data/cama_i05h_stations.nc\",\n",
+ " },\n",
+ " \"statistics\": {\n",
+ " \"i05j\": \"~/git/hat/data/observations/statistics_i05j.nc\",\n",
+ " \"i05h\": \"~/git/hat/data/observations/statistics_i05h.nc\",\n",
+ " },\n",
+ " \"station_epsg\": 4326,\n",
+ " \"station_id_column_name\": \"station_id\",\n",
+ " \"station_filters\":\"\",\n",
+ " \"station_coordinates\": [\"StationLon\", \"StationLat\"],\n",
+ " \"obs_var_name\": \"obsdis\",\n",
+ " \"sims_var_name\": \"dis\",\n",
+ "}\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Initialise the map object from the configuration."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [
+ {
+ "ename": "Exception",
+ "evalue": "Could not open file ~/git/hat/data/outlets_v4.0_20230726_withEFAS.csv",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mCPLE_OpenFailedError\u001b[0m Traceback (most recent call last)",
+ "File \u001b[0;32mfiona/_shim.pyx:83\u001b[0m, in \u001b[0;36mfiona._shim.gdal_open_vector\u001b[0;34m()\u001b[0m\n",
+ "File \u001b[0;32mfiona/_err.pyx:291\u001b[0m, in \u001b[0;36mfiona._err.exc_wrap_pointer\u001b[0;34m()\u001b[0m\n",
+ "\u001b[0;31mCPLE_OpenFailedError\u001b[0m: /home/macw/git/hat/data/outlets_v4.0_20230726_withEFAS.csv: No such file or directory",
+ "\nDuring handling of the above exception, another exception occurred:\n",
+ "\u001b[0;31mDriverError\u001b[0m Traceback (most recent call last)",
+ "File \u001b[0;32m/etc/ecmwf/nfs/dh1_home_a/macw/git/hat/hat/observations.py:42\u001b[0m, in \u001b[0;36mread_station_metadata_file\u001b[0;34m(fpath, coord_names, epsg, filters)\u001b[0m\n\u001b[1;32m 41\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m is_csv(fpath):\n\u001b[0;32m---> 42\u001b[0m gdf \u001b[38;5;241m=\u001b[39m \u001b[43mread_csv_and_cache\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfpath\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 43\u001b[0m gdf \u001b[38;5;241m=\u001b[39m add_geometry_column(gdf, coord_names)\n",
+ "File \u001b[0;32m/etc/ecmwf/nfs/dh1_home_a/macw/git/hat/hat/data.py:234\u001b[0m, in \u001b[0;36mread_csv_and_cache\u001b[0;34m(fpath)\u001b[0m\n\u001b[1;32m 232\u001b[0m \u001b[38;5;66;03m# otherwise load from user defined filepath (and then cache)\u001b[39;00m\n\u001b[1;32m 233\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 234\u001b[0m gdf \u001b[38;5;241m=\u001b[39m \u001b[43mgpd\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mread_file\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfpath\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 235\u001b[0m gdf\u001b[38;5;241m.\u001b[39mto_pickle(cache_fpath)\n",
+ "File \u001b[0;32m/usr/local/apps/python3/3.10.10-01/lib/python3.10/site-packages/geopandas/io/file.py:259\u001b[0m, in \u001b[0;36m_read_file\u001b[0;34m(filename, bbox, mask, rows, engine, **kwargs)\u001b[0m\n\u001b[1;32m 258\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m engine \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfiona\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n\u001b[0;32m--> 259\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_read_file_fiona\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 260\u001b[0m \u001b[43m \u001b[49m\u001b[43mpath_or_bytes\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfrom_bytes\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbbox\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mbbox\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmask\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mrows\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrows\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\n\u001b[1;32m 261\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 262\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m engine \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mpyogrio\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n",
+ "File \u001b[0;32m/usr/local/apps/python3/3.10.10-01/lib/python3.10/site-packages/geopandas/io/file.py:303\u001b[0m, in \u001b[0;36m_read_file_fiona\u001b[0;34m(path_or_bytes, from_bytes, bbox, mask, rows, where, **kwargs)\u001b[0m\n\u001b[1;32m 302\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m fiona_env():\n\u001b[0;32m--> 303\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[43mreader\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpath_or_bytes\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;28;01mas\u001b[39;00m features:\n\u001b[1;32m 304\u001b[0m crs \u001b[38;5;241m=\u001b[39m features\u001b[38;5;241m.\u001b[39mcrs_wkt\n",
+ "File \u001b[0;32m/usr/local/apps/python3/3.10.10-01/lib/python3.10/site-packages/fiona/env.py:408\u001b[0m, in \u001b[0;36mensure_env_with_credentials..wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 407\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m local\u001b[38;5;241m.\u001b[39m_env:\n\u001b[0;32m--> 408\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mf\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 409\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n",
+ "File \u001b[0;32m/usr/local/apps/python3/3.10.10-01/lib/python3.10/site-packages/fiona/__init__.py:264\u001b[0m, in \u001b[0;36mopen\u001b[0;34m(fp, mode, driver, schema, crs, encoding, layer, vfs, enabled_drivers, crs_wkt, **kwargs)\u001b[0m\n\u001b[1;32m 263\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m mode \u001b[38;5;129;01min\u001b[39;00m (\u001b[38;5;124m'\u001b[39m\u001b[38;5;124ma\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mr\u001b[39m\u001b[38;5;124m'\u001b[39m):\n\u001b[0;32m--> 264\u001b[0m c \u001b[38;5;241m=\u001b[39m \u001b[43mCollection\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpath\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmode\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdriver\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdriver\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mencoding\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mencoding\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 265\u001b[0m \u001b[43m \u001b[49m\u001b[43mlayer\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlayer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43menabled_drivers\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43menabled_drivers\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 266\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m mode \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mw\u001b[39m\u001b[38;5;124m'\u001b[39m:\n",
+ "File \u001b[0;32m/usr/local/apps/python3/3.10.10-01/lib/python3.10/site-packages/fiona/collection.py:162\u001b[0m, in \u001b[0;36mCollection.__init__\u001b[0;34m(self, path, mode, driver, schema, crs, encoding, layer, vsi, archive, enabled_drivers, crs_wkt, ignore_fields, ignore_geometry, **kwargs)\u001b[0m\n\u001b[1;32m 161\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msession \u001b[38;5;241m=\u001b[39m Session()\n\u001b[0;32m--> 162\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msession\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstart\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 163\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmode \u001b[38;5;129;01min\u001b[39;00m (\u001b[38;5;124m'\u001b[39m\u001b[38;5;124ma\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mw\u001b[39m\u001b[38;5;124m'\u001b[39m):\n",
+ "File \u001b[0;32mfiona/ogrext.pyx:540\u001b[0m, in \u001b[0;36mfiona.ogrext.Session.start\u001b[0;34m()\u001b[0m\n",
+ "File \u001b[0;32mfiona/_shim.pyx:90\u001b[0m, in \u001b[0;36mfiona._shim.gdal_open_vector\u001b[0;34m()\u001b[0m\n",
+ "\u001b[0;31mDriverError\u001b[0m: /home/macw/git/hat/data/outlets_v4.0_20230726_withEFAS.csv: No such file or directory",
+ "\nDuring handling of the above exception, another exception occurred:\n",
+ "\u001b[0;31mException\u001b[0m Traceback (most recent call last)",
+ "Cell \u001b[0;32mIn[2], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mhat\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01minteractive\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mexplorers\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m TimeSeriesExplorer\n\u001b[0;32m----> 2\u001b[0m \u001b[38;5;28mmap\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[43mTimeSeriesExplorer\u001b[49m\u001b[43m(\u001b[49m\u001b[43mconfig\u001b[49m\u001b[43m)\u001b[49m\n",
+ "File \u001b[0;32m/etc/ecmwf/nfs/dh1_home_a/macw/git/hat/hat/interactive/explorers.py:211\u001b[0m, in \u001b[0;36mTimeSeriesExplorer.__init__\u001b[0;34m(self, config)\u001b[0m\n\u001b[1;32m 158\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 159\u001b[0m \u001b[38;5;124;03mInitializes an instance of the Explorer class.\u001b[39;00m\n\u001b[1;32m 160\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 207\u001b[0m \n\u001b[1;32m 208\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 209\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconfig \u001b[38;5;241m=\u001b[39m config\n\u001b[0;32m--> 211\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstations_metadata \u001b[38;5;241m=\u001b[39m \u001b[43mread_station_metadata_file\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 212\u001b[0m \u001b[43m \u001b[49m\u001b[43mfpath\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mconfig\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mstations\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 213\u001b[0m \u001b[43m \u001b[49m\u001b[43mcoord_names\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mconfig\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mstation_coordinates\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 214\u001b[0m \u001b[43m \u001b[49m\u001b[43mepsg\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mconfig\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mstation_epsg\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 215\u001b[0m \u001b[43m \u001b[49m\u001b[43mfilters\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mconfig\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mstation_filters\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 216\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 218\u001b[0m \u001b[38;5;66;03m# Use the external functions to prepare data\u001b[39;00m\n\u001b[1;32m 219\u001b[0m sim_ds \u001b[38;5;241m=\u001b[39m prepare_simulations_data(\n\u001b[1;32m 220\u001b[0m config[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124msimulations\u001b[39m\u001b[38;5;124m\"\u001b[39m], config[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124msims_var_name\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n\u001b[1;32m 221\u001b[0m )\n",
+ "File \u001b[0;32m/etc/ecmwf/nfs/dh1_home_a/macw/git/hat/hat/observations.py:48\u001b[0m, in \u001b[0;36mread_station_metadata_file\u001b[0;34m(fpath, coord_names, epsg, filters)\u001b[0m\n\u001b[1;32m 46\u001b[0m gdf \u001b[38;5;241m=\u001b[39m gpd\u001b[38;5;241m.\u001b[39mread_file(fpath)\n\u001b[1;32m 47\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m:\n\u001b[0;32m---> 48\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mCould not open file \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mfpath\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 50\u001b[0m \u001b[38;5;66;03m# (optionally) filter the stations, e.g. 'Contintent == Europe'\u001b[39;00m\n\u001b[1;32m 51\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m filters \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n",
+ "\u001b[0;31mException\u001b[0m: Could not open file ~/git/hat/data/outlets_v4.0_20230726_withEFAS.csv"
+ ]
+ }
+ ],
+ "source": [
+ "from hat.interactive.explorers import TimeSeriesExplorer\n",
+ "map = TimeSeriesExplorer(config)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Display map and use the tool to show plot of the simulation time series."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "map.mapplot(colorby='kge', sim='i05j')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "hat-venv",
+ "language": "python",
+ "name": "hat-venv"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.10"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/notebooks/examples/4_visualisation_interactive_template.ipynb b/notebooks/examples/4_visualisation_interactive_template.ipynb
new file mode 100644
index 0000000..544522f
--- /dev/null
+++ b/notebooks/examples/4_visualisation_interactive_template.ipynb
@@ -0,0 +1,119 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "1. Import visualisation library and define inputs (stations, observations, simulation, statistics, and also the config that may differ depending on the simulation)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from hat.visualisation import InteractiveMap\n",
+ "\n",
+ "stations= '~/destinE/outlets_v4.0_20230726_withEFAS.csv'\n",
+ "observations= '~/destinE/destine_observation.nc'\n",
+ "simulations = {\n",
+ " \"i05j\": \"~/destinE/cama_i05j_stations.nc\",\n",
+ " \"i05h\": \"~/destinE/cama_i05h_stations.nc\"\n",
+ "}\n",
+ "# xarray input\n",
+ "\n",
+ "statistics = {\n",
+ " \"i05j\": \"~/destinE/statistics_i05j.nc\",\n",
+ " \"i05h\": \"~/destinE/statistics_i05h.nc\"\n",
+ "}\n",
+ "# xarray input\n",
+ "\n",
+ "config = {\n",
+ " \"station_epsg\": 4326,\n",
+ " \"station_id_column_name\": \"station_id\",\n",
+ " \"station_filters\":\"\",\n",
+ " \"station_coordinates\": [\"StationLon\", \"StationLat\"],\n",
+ " \"obs_var_name\": \"obsdis\"\n",
+ "}"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "2. Initialise & processing all the input data into the map object"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Found 1614 common stations\n"
+ ]
+ }
+ ],
+ "source": [
+ "map = InteractiveMap(config, stations, observations, simulations, stats=statistics)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "3. Display map interface for interactive viewing"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "44784a504c27442990df5eb6a5f26a72",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "VBox(children=(Label(value='Interactive Map Visualisation for Hydrological Model Performance', layout=Layout(j…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "map.mapplot(colorby='kge', sim='i05h', range=[-1, 1]) "
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "hat-dev",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.12"
+ },
+ "orig_nbformat": 4
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/tests/test_hydrostats_decorators.py b/tests/test_hydrostats_decorators.py
index 3a53a38..c3839aa 100644
--- a/tests/test_hydrostats_decorators.py
+++ b/tests/test_hydrostats_decorators.py
@@ -63,10 +63,8 @@ def test_filter_nan():
assert np.allclose(decorated(nan3, nan3), np.array([2, 4]))
# all nans
- with pytest.raises(ValueError):
- decorated(arr, nans)
- with pytest.raises(ValueError):
- decorated(nans, arr)
+ assert decorated(arr, nans) is None
+ assert decorated(nans, arr) is None
# def test_handle_divide_by_zero_error():
@@ -124,8 +122,7 @@ def test_hydrostat():
# # all zero division
# with pytest.raises(ZeroDivisionError):
- # decorated_divide(ones, zeros)
+ # print(decorated_divide(ones, zeros))
# all nans
- with pytest.raises(ValueError):
- decorated_divide(nans, nans)
+ assert decorated_divide(nans, nans) is None
diff --git a/tests/test_interactive.py b/tests/test_interactive.py
new file mode 100644
index 0000000..de24d32
--- /dev/null
+++ b/tests/test_interactive.py
@@ -0,0 +1,215 @@
+import time
+
+import geopandas as gpd
+import numpy as np
+import pandas as pd
+import pytest
+import xarray as xr
+from shapely.geometry import Point
+
+from hat.interactive import leaflet as lf
+from hat.interactive import widgets as wd
+
+
+class TestThrottledClick:
+ def test_delay(self):
+ throttler = wd.ThrottledClick(0.01)
+ assert throttler.should_process() is True
+ assert throttler.should_process() is False
+ time.sleep(0.015)
+ assert throttler.should_process() is True
+
+
+class DummyWidget(wd.Widget):
+ def __init__(self):
+ super().__init__(output=None)
+ self.index = None
+
+ def update(self, index, metadata, **kwargs):
+ self.index = index
+
+
+class TestWidgetsManager:
+ def test_update(self):
+ dummy = DummyWidget()
+ widgets = wd.WidgetsManager(widgets={"dummy": dummy}, index_column="station")
+ feature = {
+ "properties": {
+ "station": "A",
+ }
+ }
+ widgets.update(feature)
+ assert dummy.index == "A"
+
+
+def test_filter_nan():
+ dates = np.arange(10)
+ values = np.arange(10, dtype=float)
+ values[5] = np.nan
+ dates_new, values_new = wd._filter_nan_values(dates, values)
+ assert len(dates_new) == 9
+ assert len(values_new) == 9
+
+
+class TestPlotlyWidget:
+ def test_update(self):
+ datasets = {
+ "obs": xr.DataArray(
+ [[0, 3, 6], [0, 3, 6]],
+ coords={
+ "time": np.array(
+ ["2007-07-13", "2007-01-14"], dtype="datetime64[ns]"
+ ),
+ "station": [0, 1, 2],
+ },
+ ),
+ "sim1": xr.DataArray(
+ [[1, 2, 3], [1, 2, 3]],
+ coords={
+ "time": np.array(
+ ["2007-07-13", "2007-01-14"], dtype="datetime64[ns]"
+ ),
+ "station": [0, 1, 2],
+ },
+ ),
+ "sim2": xr.DataArray(
+ [[4, 5, 6], [4, 5, 6]],
+ coords={
+ "time": np.array(
+ ["2007-07-13", "2007-01-14"], dtype="datetime64[ns]"
+ ),
+ "station": [0, 1, 2],
+ },
+ ),
+ }
+ widget = wd.PlotlyWidget(
+ datasets=datasets,
+ )
+ assert widget.update(1) is True # if id found
+ assert widget.update(3) is False # if id not found
+
+
+class TestMetaDataWidget:
+ def test_update(self):
+ df = pd.DataFrame(
+ {"col1": [1, 2, 3], "col2": ["a", "b", "c"], "col3": [0.1, 0.2, 0.3]}
+ )
+ widget = wd.MetaDataWidget(df, "col2")
+ assert widget.update("a") is True
+ assert widget.update("f") is False
+
+
+class TestStatisticsWidget:
+ def test_update(self):
+ datasets = {
+ "sim1": xr.Dataset(
+ {"data": [0, 1, 2]},
+ coords={
+ "station": [0, 1, 2],
+ },
+ ),
+ "sim2": xr.Dataset(
+ {"data": [4, 5, 6]},
+ coords={
+ "station": [0, 1, 2],
+ },
+ ),
+ }
+ widget = wd.StatisticsWidget(datasets)
+ assert widget.update(0) is True
+ assert widget.update(4) is False
+
+ def test_fail_creation(self):
+ datasets = {
+ "sim1": xr.Dataset(
+ {"data": [0, 1, 2]},
+ coords={
+ "id": [0, 1, 2],
+ },
+ ),
+ }
+ with pytest.raises(AssertionError):
+ wd.StatisticsWidget(datasets)
+
+
+class TestPyleafletColormap:
+ config = {
+ "station_id_column_name": "station",
+ }
+ stats = xr.DataArray(
+ [0, 1, 2],
+ coords={
+ "station": [0, 1, 2],
+ },
+ )
+
+ def test_default_creation(self):
+ lf.PyleafletColormap()
+
+ def test_creation_with_stats(self):
+ lf.PyleafletColormap(
+ self.config, self.stats, colormap_style="plasma", range=[1, 2]
+ )
+
+ def test_creation_wrong_colormap(self):
+ with pytest.raises(KeyError):
+ lf.PyleafletColormap(
+ self.config, self.stats, colormap_style="awdawd", range=[1, 2]
+ )
+
+ def test_fail_creation(self):
+ config = {}
+ with pytest.raises(AssertionError):
+ lf.PyleafletColormap(config, self.stats)
+
+ def test_default_style(self):
+ colormap = lf.PyleafletColormap(self.config)
+ style_fct = colormap.style_callback()
+ style_fct(feature={})
+
+ def test_stats_style(self):
+ feature = {
+ "properties": {
+ "station": 2,
+ }
+ }
+ colormap = lf.PyleafletColormap(self.config, self.stats)
+ style_fct = colormap.style_callback()
+ style_fct(feature)
+
+ def test_stats_style_fail(self):
+ feature = {
+ "properties": {
+ "station": 4,
+ }
+ }
+ colormap = lf.PyleafletColormap(self.config, self.stats, empty_color="black")
+ style_fct = colormap.style_callback()
+ style = style_fct(feature)
+ assert style["fillColor"] == "black"
+
+ def test_stats_style_default(self):
+ colormap = lf.PyleafletColormap(default_color="blue")
+ style_fct = colormap.style_callback()
+ style = style_fct({})
+ assert style["fillColor"] == "blue"
+
+ def test_legend(self):
+ colormap = lf.PyleafletColormap(self.config, self.stats)
+ colormap.legend()
+
+
+class TestLeafletMap:
+ def test_creation(self):
+ lf.LeafletMap()
+
+ def test_output(self):
+ map = lf.LeafletMap()
+ map.output()
+
+ def test_add_geolayer(self):
+ map = lf.LeafletMap()
+ widgets = {}
+ colormap = lf.PyleafletColormap()
+ gdf = gpd.GeoDataFrame(geometry=[Point(0, 0)])
+ map.add_geolayer(gdf, colormap, widgets)