diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 43a5334..17c11b7 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -28,16 +28,11 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 - - uses: conda-incubator/setup-miniconda@v2 + - uses: mamba-org/setup-micromamba@v1 with: activate-environment: test environment-file: environment.yml auto-activate-base: false - - name: conda check - shell: bash -l {0} - run: | - conda info - conda list - name: install hat package shell: bash -l {0} run: pip install . diff --git a/.github/workflows/on-push.yml b/.github/workflows/on-push.yml index c796d57..968e5fa 100644 --- a/.github/workflows/on-push.yml +++ b/.github/workflows/on-push.yml @@ -32,16 +32,11 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 - - uses: conda-incubator/setup-miniconda@v2 + - uses: mamba-org/setup-micromamba@v1 with: activate-environment: test environment-file: environment.yml auto-activate-base: false - - name: Conda check - shell: bash -l {0} - run: | - conda info - conda list - name: install hat package shell: bash -l {0} run: pip install . diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 30af686..5b536b2 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -7,7 +7,7 @@ repos: rev: 23.3.0 hooks: - id: black - - repo: https://github.com/PyCQA/flake8 - rev: 6.0.0 - hooks: - - id: flake8 +# - repo: https://github.com/PyCQA/flake8 +# rev: 6.0.0 +# hooks: +# - id: flake8 diff --git a/README.md b/README.md index 0d825dc..69cc149 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,7 @@ Interfaces and functionality are likely to change, and the project itself may be Clone source code repository - $ git clone git@github.com:ecmwf-projects/hat.git + $ git clone https://github.com/ecmwf/hat.git Create conda python environment @@ -63,4 +63,4 @@ does it submit to any jurisdiction. 
### Citing -In publications, please use a link to this repository (https://github.com/ecmwf/hat) and its documentation (https://hydro-analysis-toolkit.readthedocs.io) \ No newline at end of file +In publications, please use a link to this repository (https://github.com/ecmwf/hat) and its documentation (https://hydro-analysis-toolkit.readthedocs.io) diff --git a/environment.yml b/environment.yml index 6b28254..1c2e5c2 100644 --- a/environment.yml +++ b/environment.yml @@ -2,21 +2,21 @@ name: hat channels: - conda-forge dependencies: - - python=3.10 + - python<=3.10 - netCDF4 - eccodes - cfgrib - - cftime - geopandas - xarray - plotly - matplotlib - - jupyterlab + - jupyter - tqdm - typer - humanize - - folium - typer + - ipyleaflet + - ipywidgets - pip # - pytest # - mkdocs diff --git a/hat/clock.py b/hat/clock.py deleted file mode 100644 index 409fa97..0000000 --- a/hat/clock.py +++ /dev/null @@ -1,71 +0,0 @@ -import datetime -import functools -import time - -import humanize - - -def digital_clock(func): - "A quiet decorator for timing functions" - - @functools.wraps(func) - def clocked(*args, **kwargs): - # time - t0 = time.perf_counter() - result = func(*args, **kwargs) - elapsed = time.perf_counter() - t0 - human_readable_time = humanize.naturaldelta(datetime.timedelta(seconds=elapsed)) - - # name - name = func.__name__ - - print(f"{name}() took {human_readable_time}") - return result - - # return function with timing decorator - return clocked - - -def clock(func): - "A verbose decorator for timing functions" - - @functools.wraps(func) - def clocked(*args, **kwargs): - # time - t0 = time.perf_counter() - result = func(*args, **kwargs) - elapsed = time.perf_counter() - t0 - human_readable_time = humanize.naturaldelta(datetime.timedelta(seconds=elapsed)) - - # name - name = func.__name__ - - # arguments - arg_str = ", ".join(repr(arg) for arg in args) - if arg_str == "": - arg_str = "no arguments" - - # keywords - pairs = [f"{k}={w}" for k, w in sorted(kwargs.items())] - key_str = ", ".join(pairs) - if key_str == "": - key_str = "no keywords" - - print( - f"""{name}() took {human_readable_time} to run - with following inputs {arg_str} and {key_str}""" - ) - return result - - # return function with timing decorator - return clocked - - -if __name__ == "__main__": - - @clock - def test(arg1, keyword=False): - time.sleep(0.1) - pass - - test(1, keyword="hello") diff --git a/hat/data.py b/hat/data.py index 4f25eee..d1ff6e0 100644 --- a/hat/data.py +++ b/hat/data.py @@ -37,6 +37,9 @@ def get_tmpdir(): else: tmpdir = TemporaryDirectory().name + # Ensure the directory exists + os.makedirs(tmpdir, exist_ok=True) + return tmpdir @@ -257,8 +260,8 @@ def save_dataset_to_netcdf(ds: xr.Dataset, fpath: str): ds.to_netcdf(fpath) -def find_main_var(ds): - variable_names = [k for k in ds.variables if len(ds.variables[k].dims) >= 3] +def find_main_var(ds, min_dim=3): + variable_names = [k for k in ds.variables if len(ds.variables[k].dims) >= min_dim] if len(variable_names) > 1: raise Exception("More than one variable in dataset") elif len(variable_names) == 0: diff --git a/hat/filters.py b/hat/filters.py index 88cdb07..bc0aebb 100644 --- a/hat/filters.py +++ b/hat/filters.py @@ -144,7 +144,7 @@ def filter_dataframe(df, filters: str): return df -def filter_timeseries(sims_ds: xr.Dataset, obs_ds: xr.Dataset, threshold=80): +def filter_timeseries(sims_ds: xr.DataArray, obs_ds: xr.DataArray, threshold=80): """Clean the simulation and observation timeseries Only keep.. 
@@ -159,30 +159,33 @@ def filter_timeseries(sims_ds: xr.Dataset, obs_ds: xr.Dataset, threshold=80):
     matching_stations = sorted(
         set(sims_ds.station.values).intersection(obs_ds.station.values)
     )
     sims_ds = sims_ds.sel(station=matching_stations)
     obs_ds = obs_ds.sel(station=matching_stations)
+    obs_ds = obs_ds.sel(time=sims_ds.time)
+
+    obs_ds = obs_ds.dropna(dim="station", how="all")
+    sims_ds = sims_ds.sel(station=obs_ds.station)
 
     # Only keep observations in the same time period as the simulations
-    obs_ds = obs_ds.where(sims_ds.time == obs_ds.time, drop=True)
+    # obs_ds = obs_ds.where(sims_ds.time == obs_ds.time, drop=True)
 
     # Only keep observations with enough valid data in this time period
 
     # discharge data
-    dis = obs_ds.obsdis
 
     # Replace negative values with NaN
-    dis = dis.where(dis >= 0)
+    # dis = dis.where(dis >= 0)
 
-    # Percentage of valid discharge data at each point in time
-    valid_percent = dis.notnull().mean(dim="time") * 100
+    # # Percentage of valid discharge data at each point in time
+    # valid_percent = dis.notnull().mean(dim="time") * 100
 
-    # Boolean index of where there is enough valid data
-    enough_observation_data = valid_percent > threshold
+    # # Boolean index of where there is enough valid data
+    # enough_observation_data = valid_percent > threshold
 
-    # keep where there is enough observation data
-    obs_ds = obs_ds.where(enough_observation_data, drop=True)
+    # # keep where there is enough observation data
+    # obs_ds = obs_ds.where(enough_observation_data, drop=True)
 
-    # keep simulation that match remaining observations
-    sims_ds = sims_ds.where(enough_observation_data, drop=True)
+    # # keep simulations that match remaining observations
+    # sims_ds = sims_ds.where(enough_observation_data, drop=True)
 
     return (sims_ds, obs_ds)
diff --git a/hat/graphs.py b/hat/graphs.py
deleted file mode 100644
index 8f6ba39..0000000
--- a/hat/graphs.py
+++ /dev/null
@@ -1,30 +0,0 @@
-import pandas as pd
-import plotly.express as px
-
-
-def graph_sims_and_obs(
-    sims,
-    obs,
-    ID,
-    sims_data_name="simulation_timeseries",
-    obs_data_name="obsdis",
-    height=500,
-    width=1200,
-):
-    # observations, simulations, time
-    o = obs.sel(station=ID)[obs_data_name].values
-    s = sims.sel(station=ID)[sims_data_name].values
-    t = obs.sel(station=ID).time.values
-
-    df = pd.DataFrame({"time": t, "simulations": s, "observations": o})
-    fig = px.line(
-        df,
-        x="time",
-        y=["simulations", "observations"],
-        title="Simulations & Observations",
-    )
-    fig.data[0].line.color = "#34eb7d"
-    fig.data[1].line.color = "#3495eb"
-    fig.update_layout(height=height, width=width)
-    fig.update_yaxes(title_text="discharge")
-    fig.show()
diff --git a/hat/hydrostats.py b/hat/hydrostats.py
index d03c3fe..79238a9 100644
--- a/hat/hydrostats.py
+++ b/hat/hydrostats.py
@@ -1,30 +1,25 @@
 """high level python api for hydrological statistics"""
 from typing import List
 
-import folium
-import geopandas as gpd
-import pandas as pd
+import numpy as np
 import xarray as xr
-from branca.colormap import linear
-from folium.plugins import Fullscreen
 
 from hat import hydrostats_functions
 
 
 def run_analysis(
     functions: List,
-    sims_ds: xr.Dataset,
-    obs_ds: xr.Dataset,
-    sims_var_name="simulation_timeseries",
-    obs_var_name="obsdis",
+    sims_ds: xr.DataArray,
+    obs_ds: xr.DataArray,
 ) -> xr.Dataset:
-    """Run statistical analysis on simulation and observation timeseries"""
+    """
+    Run statistical analysis on simulation and observation timeseries
+    """
 
     # list of stations
     stations = 
sims_ds.coords["station"].values - # Create an empty DataFrame with stations as the index - df = pd.DataFrame(index=stations) + ds = xr.Dataset() # For each statistical function for name in functions: @@ -33,95 +28,19 @@ def run_analysis( # do timeseries analysis for each station # (using a "numpy in, numpy out" function) - statistics = {} + statistics = [] for station in stations: - sims = sims_ds.sel(station=station)[sims_var_name].to_numpy() - obs = obs_ds.sel(station=station)[obs_var_name].to_numpy() - statistics[station] = func(sims, obs) + sims = sims_ds.sel(station=station).to_numpy() + obs = obs_ds.sel(station=station).to_numpy() - # 1D Series of statistics (i.e. scalar value per station) - statistics_series = pd.Series(statistics, name=name) + stat = func(sims, obs) + if stat is None: + print(f"Warning! All NaNs for station {station}") + stat = 0 + statistics += [stat] + statistics = np.array(statistics) # Add the Series to the DataFrame - df[name] = statistics_series - - # Convert the DataFrame to an xarray Dataset - ds = df.to_xarray() - ds = ds.rename({"index": "station"}) - ds["longitude"] = ("station", sims_ds.coords["longitude"].data) - ds["latitude"] = ("station", sims_ds.coords["latitude"].data) + ds[name] = xr.DataArray(statistics, coords={"station": stations}) return ds - - -def display_map(ds: xr.Dataset, name: str, minv: float = 0, maxv: float = 1): - # xarray to geopandas - gdf = gpd.GeoDataFrame( - ds.to_dataframe(), - geometry=gpd.points_from_xy(ds["longitude"], ds["latitude"]), - crs="epsg:4326", - ) - gdf["station_id"] = gdf.index - gdf = gdf[~gdf[name].isnull()] - - # Create a color map - colormap = linear.Blues_09.scale(minv, maxv) - - # Define style function - def style_function(feature): - property = feature["properties"][name] - return {"fillOpacity": 0.7, "weight": 0, "fillColor": colormap(property)} - - m = folium.Map(location=[48, 5], zoom_start=5, prefer_canvas=True, tiles=None) - _ = folium.GeoJson( - gdf, - marker=folium.CircleMarker(fillColor="white", fillOpacity=0.5, radius=5), - name=name, - style_function=style_function, - tooltip=folium.GeoJsonTooltip(fields=["station_id", name]), - popup=folium.GeoJsonPopup(fields=["station_id", name]), - ).add_to(m) - - # Add the CartoDB Positron tileset as a layer - cartodb_positron = folium.TileLayer( - tiles="CartoDB Dark_Matter", - name="Dark", - overlay=False, - control=True, - ) - cartodb_positron.add_to(m) - - # Add the CartoDB Positron tileset as a layer - cartodb_positron = folium.TileLayer( - tiles="CartoDB Positron", - name="Light", - overlay=False, - control=True, - ) - cartodb_positron.add_to(m) - - # Add OpenStreetMap layer - open_street_map = folium.TileLayer( - tiles="OpenStreetMap", - name="Open Street Map", - overlay=False, - control=True, - ) - open_street_map.add_to(m) - - # Add the satellite layer - esri_satellite = folium.TileLayer( - tiles="""https://server.arcgisonline.com/ArcGIS/rest/services/ - World_Imagery/MapServer/tile/{z}/{y}/{x}""", - attr="Esri", - name="Satellite", - overlay=False, - control=True, - ) - esri_satellite.add_to(m) - - # add controls - folium.LayerControl().add_to(m) - Fullscreen().add_to(m) - - return m diff --git a/hat/hydrostats_decorators.py b/hat/hydrostats_decorators.py index cfe6817..e0a219a 100644 --- a/hat/hydrostats_decorators.py +++ b/hat/hydrostats_decorators.py @@ -49,7 +49,7 @@ def wrapper(arr1: np.ndarray, arr2: np.ndarray): nan_mask2 = np.isnan(arr2) filtered_mask = ~(nan_mask1 | nan_mask2) if not np.any(filtered_mask): - raise ValueError("All 
elements are NaN")
+            return None
         return func(arr1[filtered_mask], arr2[filtered_mask])
 
     return wrapper
diff --git a/hat/images.py b/hat/images.py
deleted file mode 100644
index 14ebeae..0000000
--- a/hat/images.py
+++ /dev/null
@@ -1,29 +0,0 @@
-import matplotlib
-import numpy as np
-from quicklook import quicklook
-
-
-def arr_to_image(arr: np.array) -> np.array:
-    """modify array so that it is optimized for viewing"""
-
-    # image array
-    img = np.array(arr)
-
-    img = quicklook.replace_nan(img)
-    img = quicklook.percentile_clip(img, 2)
-    img = quicklook.bytescale(img)
-    img = quicklook.reshape_array(img)
-
-    return img
-
-
-def numpy_to_png(
-    arr: np.array, dim="time", index="somedate", fpath="image.png"
-) -> None:
-    """Save numpy array to png"""
-
-    # image from array
-    img = arr_to_image(arr)
-
-    # save to file
-    matplotlib.image.imsave(fpath, img)
diff --git a/hat/interactive/__init__.py b/hat/interactive/__init__.py
new file mode 100644
index 0000000..99567ee
--- /dev/null
+++ b/hat/interactive/__init__.py
@@ -0,0 +1,12 @@
+# flake8: noqa
+from .explorers import TimeSeriesExplorer
+from .leaflet import LeafletMap, PyleafletColormap
+from .widgets import (
+    DataFrameWidget,
+    HTMLTableWidget,
+    MetaDataWidget,
+    PlotlyWidget,
+    StatisticsWidget,
+    Widget,
+    WidgetsManager,
+)
diff --git a/hat/interactive/explorers.py b/hat/interactive/explorers.py
new file mode 100644
index 0000000..0c88d4c
--- /dev/null
+++ b/hat/interactive/explorers.py
@@ -0,0 +1,358 @@
+import os
+
+import ipywidgets
+import pandas as pd
+import xarray as xr
+from IPython.display import display
+
+from hat.interactive.leaflet import LeafletMap, PyleafletColormap
+from hat.interactive.widgets import (
+    MetaDataWidget,
+    PlotlyWidget,
+    StatisticsWidget,
+    WidgetsManager,
+)
+from hat.observations import read_station_metadata_file
+
+
+def prepare_simulations_data(simulations, sims_var_name):
+    """
+    Process simulations and put them in a dictionary of xarray data arrays.
+
+    Parameters
+    ----------
+    simulations : dict
+        A dictionary of paths to the simulation netCDF files, with the keys
+        being the simulation names.
+    sims_var_name : str
+        The name of the variable in the simulation netCDF files that contains
+        the simulated values.
+
+    Returns
+    -------
+    dict
+        A dictionary of xarray data arrays containing the simulation data.
+
+    """
+    # If simulations is a dictionary, load data for each experiment
+    sim_ds = {}
+    for exp, path in simulations.items():
+        # Expanding the tilde
+        expanded_path = os.path.expanduser(path)
+
+        if os.path.isfile(expanded_path):  # Check if it's a file
+            ds = xr.open_dataset(expanded_path)
+            sim_ds[exp] = ds[sims_var_name]
+
+    return sim_ds
+
+
+def prepare_observations_data(observations, sim_ds, obs_var_name):
+    """
+    Process the raw observation dataset into a standard xarray data array.
+    The observation dataset can be either a csv file or a netcdf file.
+    The observation dataset is subsetted based on the time values of the
+    simulation dataset.
+
+    Parameters
+    ----------
+    observations : str
+        The path to the observation file (CSV or netCDF).
+    sim_ds : dict or xarray.Dataset
+        A dictionary of xarray datasets containing the simulation data, or a
+        single xarray dataset.
+    obs_var_name : str
+        The name of the variable in the observation netCDF file that contains
+        the observed values.
+
+    Returns
+    -------
+    xarray.DataArray
+        An xarray data array containing the observation data.
+
+    Raises
+    ------
+    ValueError
+        If the file format of the observations file is not supported.
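+
+    Examples
+    --------
+    A hypothetical call, with illustrative file paths; the variable names
+    come from the example notebook::
+
+        sim_ds = prepare_simulations_data({"i05j": "sims_i05j.nc"}, "dis")
+        obs_ds = prepare_observations_data("observations.nc", sim_ds, "obsdis")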
+ + """ + file_extension = os.path.splitext(observations)[-1].lower() + + if file_extension == ".csv": + obs_df = pd.read_csv(observations, parse_dates=["Timestamp"]) + obs_melted = obs_df.melt( + id_vars="Timestamp", var_name="station", value_name=obs_var_name + ) + # Convert the melted DataFrame to xarray Dataset + obs_ds = obs_melted.set_index(["Timestamp", "station"]).to_xarray() + obs_ds = obs_ds.rename({"Timestamp": "time"}) + elif file_extension == ".nc": + obs_ds = xr.open_dataset(observations) + else: + raise ValueError("Unsupported file format for observations.") + + # Subset obs_ds based on sim_ds time values + if isinstance(sim_ds, xr.Dataset): + time_values = sim_ds["time"].values + elif isinstance(sim_ds, dict): + # Use the first dataset in the dictionary to determine time values + first_dataset = next(iter(sim_ds.values())) + time_values = first_dataset["time"].values + else: + raise ValueError("Unexpected type for sim_ds") + + obs_ds = obs_ds[obs_var_name].sel(time=time_values) + return obs_ds + + +def find_common_stations(station_index, stations_metadata, obs_ds, sim_ds, statistics): + """ + Find common stations between observations, simulations and station + metadata. + + Parameters + ---------- + station_index : str + The name of the column in the station metadata file that contains the + station IDs. + stations_metadata : pandas.DataFrame + A pandas DataFrame containing the station metadata. + obs_ds : xarray.Dataset + An xarray dataset containing the observation data. + sim_ds : dict or xarray.Dataset + A dictionary of xarray data arrays containing the simulation data. + statistics : dict + A dictionary of xarray data arrays containing the statistics data. + + Returns + ------- + list + A list of common station IDs. + + """ + ids = [] + ids += [list(obs_ds["station"].values)] + ids += [list(ds["station"].values) for ds in sim_ds.values()] + ids += [stations_metadata[station_index]] + if statistics: + ids += [list(ds["station"].values) for ds in statistics.values()] + + common_ids = None + for id in ids: + if common_ids is None: + common_ids = set(id) + else: + common_ids = set(id) & common_ids + return list(common_ids) + + +class TimeSeriesExplorer: + """ + Initialize the interactive map with configurations and data sources. + """ + + def __init__(self, config): + """ + Initializes an instance of the Explorer class. + + Parameters + ---------- + config : dict + A dictionary containing the configuration parameters for the + Explorer. + + Notes + ----- + This method initializes an instance of the Explorer class with the + given configuration parameters. + The configuration parameters should be provided as a dictionary with + the following keys: + + - stations : str + The path to the station metadata file. + - observations : str + The path to the observation netCDF file. + - simulations : dict + A dictionary of paths to the simulation netCDF files, with the keys + being the simulation names. + - statistics : dict, optional + A dictionary of paths to the statistics netCDF files, with the keys + being the simulation names. + - station_coordinates : list of str + The names of the columns in the station metadata file that contain + the station coordinates. + - station_epsg : int + The EPSG code of the coordinate reference system used by the + station coordinates. + - station_filters : dict + A dictionary of filters to apply to the station metadata file. + - sims_var_name : str + The name of the variable in the simulation netCDF files that + contains the simulated values. 
+        - obs_var_name : str
+            The name of the variable in the observation netCDF file that
+            contains the observed values.
+        - station_id_column_name : str
+            The name of the column in the station metadata file that contains
+            the station IDs.
+
+        Raises
+        ------
+        AssertionError
+            If there is a mismatch between the keys of the statistics netCDF
+            files and the simulation netCDF files.
+
+        """
+        self.config = config
+
+        self.stations_metadata = read_station_metadata_file(
+            fpath=config["stations"],
+            coord_names=config["station_coordinates"],
+            epsg=config["station_epsg"],
+            filters=config["station_filters"],
+        )
+
+        # Use the external functions to prepare data
+        sim_ds = prepare_simulations_data(
+            config["simulations"], config["sims_var_name"]
+        )
+        obs_ds = prepare_observations_data(
+            config["observations"], sim_ds, config["obs_var_name"]
+        )
+
+        # set station index
+        self.station_index = config["station_id_column_name"]
+
+        # Retrieve statistics from the statistics netcdf input
+        self.statistics = {}
+        stats = config.get("statistics")
+        if stats is not None:
+            for name, path in stats.items():
+                self.statistics[name] = xr.open_dataset(path)
+
+            # Ensure the keys of self.statistics match the keys of sim_ds
+            assert set(self.statistics.keys()) == set(
+                sim_ds.keys()
+            ), "Mismatch between statistics and simulations keys."
+
+        # find common station ids between metadata, observation and simulations
+        common_ids = find_common_stations(
+            self.station_index,
+            self.stations_metadata,
+            obs_ds,
+            sim_ds,
+            self.statistics,
+        )
+
+        print(f"Found {len(common_ids)} common stations")
+        self.stations_metadata = self.stations_metadata.loc[
+            self.stations_metadata[self.station_index].isin(common_ids)
+        ]
+        obs_ds = obs_ds.sel(station=common_ids)
+        for sim, ds in sim_ds.items():
+            sim_ds[sim] = ds.sel(station=common_ids)
+
+        # Create loading widget
+        self.loading_widget = ipywidgets.Label(value="")
+
+        # Title label
+        self.title_label = ipywidgets.Label(
+            "Interactive Map Visualisation for Hydrological Model Performance",
+            layout=ipywidgets.Layout(justify_content="center"),
+            style={"font_weight": "bold", "font_size": "24px", "font_family": "Arial"},
+        )
+
+        # Create the interactive widgets
+        datasets = sim_ds
+        datasets["obs"] = obs_ds
+        widgets = {}
+        widgets["plot"] = PlotlyWidget(datasets)
+        widgets["stats"] = StatisticsWidget(self.statistics)
+        widgets["meta"] = MetaDataWidget(self.stations_metadata, self.station_index)
+        self.widgets = WidgetsManager(
+            widgets, config["station_id_column_name"], self.loading_widget
+        )
+
+        # Create the main leaflet map
+        self.leafletmap = LeafletMap()
+
+    def create_frame(self):
+        """
+        Initialize the layout of the widgets for the map visualization.
+
+        Returns
+        -------
+        ipywidgets.VBox
+            A vertical box containing the layout elements for the map
+            visualization.
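+
+        Examples
+        --------
+        Hypothetical direct use; ``plot`` normally builds and displays this
+        frame internally::
+
+            explorer = TimeSeriesExplorer(config)
+            frame = explorer.create_frame()
+            display(frame)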
+ + """ + # Layouts 2 + main_layout = ipywidgets.Layout( + justify_content="space-around", + align_items="stretch", + spacing="2px", + width="1000px", + ) + left_layout = ipywidgets.Layout( + justify_content="space-around", + align_items="center", + spacing="2px", + width="40%", + ) + right_layout = ipywidgets.Layout( + justify_content="center", align_items="center", spacing="2px", width="60%" + ) + + # Frames + top_left_frame = self.leafletmap.output(left_layout) + top_right_frame = ipywidgets.VBox( + [self.widgets["plot"].output, self.widgets["stats"].output], + layout=right_layout, + ) + main_top_frame = ipywidgets.HBox([top_left_frame, top_right_frame]) + + # Main layout + main_frame = ipywidgets.VBox( + [self.title_label, main_top_frame, self.widgets["meta"].output], + layout=main_layout, + ) + return main_frame + + def plot(self, colorby=None, sim=None, limits=None, mp_colormap="viridis"): + """ + Plot the stations markers colored by a given metric. + + Parameters + ---------- + colorby : str, optional + The name of the metric to color the stations by. + sim : str, optional + The name of the simulation to use for the metric. + limits : list, optional + A list of two values representing the minimum and maximum values + for the color bar. + mp_colormap : str, optional + The name of the matplotlib colormap to use for the color bar. + + """ + # create colormap from statistics + stats = None + if self.statistics and colorby is not None and sim is not None: + stats = self.statistics[sim][colorby] + colormap = PyleafletColormap(self.config, stats, mp_colormap, limits) + + # add layer to the leaflet map + self.leafletmap.add_geolayer( + self.stations_metadata, + colormap, + self.widgets, + self.config["station_coordinates"], + ) + + # Initialize frame elements + frame = self.create_frame() + + # Display the main layout + display(frame) diff --git a/hat/interactive/leaflet.py b/hat/interactive/leaflet.py new file mode 100644 index 0000000..10ff223 --- /dev/null +++ b/hat/interactive/leaflet.py @@ -0,0 +1,248 @@ +import json + +import ipyleaflet +import ipywidgets +import matplotlib as mpl +import matplotlib.pyplot as plt +import numpy as np + + +def _compute_bounds(stations_metadata, coord_names): + """Compute the bounds of the map based on the stations metadata.""" + + lon_column = coord_names[0] + lat_column = coord_names[1] + + lons = stations_metadata[lon_column].values + lats = stations_metadata[lat_column].values + + min_lat, max_lat = min(lats), max(lats) + min_lon, max_lon = min(lons), max(lons) + + return [(float(min_lat), float(min_lon)), (float(max_lat), float(max_lon))] + + +class LeafletMap: + """ + A class for creating interactive leaflet maps. + + Parameters + ---------- + basemap : ipyleaflet.basemaps, optional + The basemap to use for the map. Default is + ipyleaflet.basemaps.OpenStreetMap.Mapnik. + + """ + + def __init__( + self, + basemap=ipyleaflet.basemaps.OpenStreetMap.Mapnik, + ): + self.map = ipyleaflet.Map( + basemap=basemap, layout=ipywidgets.Layout(width="100%", height="600px") + ) + self.legend_widget = ipywidgets.Output() + + def _set_boundaries(self, stations_metadata, coord_names): + """ + Compute the boundaries of the map based on the stations metadata. 
+ """ + lon_column = coord_names[0] + lat_column = coord_names[1] + + lons = stations_metadata[lon_column].values + lats = stations_metadata[lat_column].values + + min_lat, max_lat = min(lats), max(lats) + min_lon, max_lon = min(lons), max(lons) + + bounds = [(float(min_lat), float(min_lon)), (float(max_lat), float(max_lon))] + self.map.fit_bounds(bounds) + + def add_geolayer(self, geodata, colormap, widgets, coord_names=None): + """ + Add a geolayer to the map. + + Parameters + ---------- + geodata : geopandas.GeoDataFrame + The geodataframe containing the geospatial data. + colormap : hat.PyleafletColormap + The colormap to use for the geolayer. + widgets : hat.WidgetsManager + The widgets to use for the geolayer. + coord_names : list of str, optional + The names of the columns containing the spatial coordinates. + Default is None. + """ + geojson = ipyleaflet.GeoJSON( + data=json.loads(geodata.to_json()), + style={ + "radius": 7, + "opacity": 0.5, + "weight": 1.9, + "dashArray": "2", + "fillOpacity": 0.5, + }, + hover_style={"radius": 10, "fillOpacity": 1}, + point_style={"radius": 5}, + style_callback=colormap.style_callback(), + ) + geojson.on_click(widgets.update) + self.map.add(geojson) + + if coord_names is not None: + self._set_boundaries(geodata, coord_names) + + self.legend_widget = colormap.legend() + + def output(self, layout={}): + """ + Return the output widget. + + Parameters + ---------- + layout : ipywidgets.Layout + The layout of the widget. + + Returns + ------- + ipywidgets.VBox + The output widget. + + """ + output = ipywidgets.VBox([self.map, self.legend_widget], layout=layout) + return output + + +class PyleafletColormap: + """ + A class handling the colormap of a pyleaflet map. + + Parameters + ---------- + config : dict + A dictionary containing configuration options for the map. + stats : xarray.Dataset or None, optional + A dataset containing the data to be plotted on the map. + If None, a default constant colormap will be used. + colormap_style : str, optional + The name of the matplotlib colormap to use. Default is 'viridis'. + range : tuple of float, optional + The minimum and maximum values of the colormap. If None, the + minimum and maximum values in `stats` will be used. + """ + + def __init__( + self, + config={}, + stats=None, + colormap_style="viridis", + range=None, + empty_color="white", + default_color="blue", + ): + self.config = config + self.stats = stats + self.empty_color = empty_color + self.default_color = default_color + if self.stats is not None: + assert ( + "station_id_column_name" in self.config + ), 'Config must contain "station_id_column_name"' + # Normalize the data for coloring + if range is None: + self.min_val = self.stats.values.min() + self.max_val = self.stats.values.max() + else: + self.min_val = range[0] + self.max_val = range[1] + else: + self.min_val = 0 + self.max_val = 1 + + try: + self.colormap = mpl.colormaps[colormap_style] + except KeyError: + raise KeyError( + f"Colormap {colormap_style} not found. " + f"Available colormaps are: {mpl.colormaps}" + ) + + def style_callback(self): + """ + Returns a function that can be used as input style for the ipyleaflet + layer. + + Returns + ------- + function + A function that takes a dataframe feature as input and returns a + dictionary of style options for the ipyleaflet layer. 
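+
+        Examples
+        --------
+        Hypothetical use as the style callback of an ipyleaflet layer, where
+        ``geodata`` is a GeoJSON-like dictionary; this mirrors what
+        ``LeafletMap.add_geolayer`` does::
+
+            colormap = PyleafletColormap(config, stats)
+            layer = ipyleaflet.GeoJSON(
+                data=geodata,
+                style_callback=colormap.style_callback(),
+            )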
+ """ + if self.stats is not None: + norm = plt.Normalize(self.min_val, self.max_val) + + def map_color(feature): + station_id = feature["properties"][ + self.config["station_id_column_name"] + ] + if station_id in self.stats.station.values: + station_stats = self.stats.sel(station=station_id) + color = mpl.colors.rgb2hex( + self.colormap(norm(station_stats.values)) + ) + else: + color = self.empty_color + style = { + "color": "black", + "fillColor": color, + } + return style + + else: + + def map_color(feature): + return { + "color": "black", + "fillColor": self.default_color, + } + + return map_color + + def legend(self): + """ + Generates an HTML legend for the colormap. + + Returns + ------- + ipywidgets.HTML + An HTML widget containing the colormap legend. + """ + # Convert the colormap to a list of RGB values + rgb_values = [ + mpl.colors.rgb2hex(self.colormap(i)) for i in np.linspace(0, 1, 256) + ] + + # Create a gradient style using the RGB values + gradient_style = ", ".join(rgb_values) + gradient_html = f""" +
<div style="background: linear-gradient(to right, {gradient_style}); height: 20px; width: 100%;"></div>
+        """
+
+        # Create labels
+        labels_html = f"""
+        <div style="display: flex; justify-content: space-between;">
+            <span>Low: {self.min_val:.1f}</span>
+            <span>High: {self.max_val:.1f}</span>
+        </div>
+ """ + # Combine gradient and labels + legend_html = gradient_html + labels_html + + return ipywidgets.HTML(legend_html) diff --git a/hat/interactive/widgets.py b/hat/interactive/widgets.py new file mode 100644 index 0000000..ba7c5d0 --- /dev/null +++ b/hat/interactive/widgets.py @@ -0,0 +1,495 @@ +import time + +import numpy as np +import pandas as pd +import plotly.graph_objs as go +from IPython.display import clear_output, display +from ipywidgets import HTML, DatePicker, HBox, Label, Layout, Output, VBox + + +class ThrottledClick: + """ + Initialize a click throttler with a given delay. + + Parameters + ---------- + delay : float, optional + The delay in seconds between clicks. Defaults to 1.0. + + Notes + ----- + This class is used to prevent users from rapidly clicking a button or widget + multiple times, which can cause the application to crash or behave unexpectedly. + + Examples + -------- + >>> click_throttler = ThrottledClick(delay=0.5) + >>> if click_throttler.should_process(): + ... # do something + """ + + def __init__(self, delay=1.0): + self.delay = delay + self.last_call = 0 + + def should_process(self): + """ + Determine if a click should be processed based on the delay. + + Returns + ------- + bool + True if the click should be processed, False otherwise. + + Notes + ----- + This method should be called before processing a click event. If the + time since the last click is greater than the delay, the method returns + True and updates the last_call attribute. Otherwise, it returns False. + """ + current_time = time.time() + if current_time - self.last_call > self.delay: + self.last_call = current_time + return True + return False + + +class WidgetsManager: + """ + A class for managing a collection of widgets and updating following + a user interaction, providing an index. + + Parameters + ---------- + widgets : dict + A dictionary of widgets to manage. + index_column : str + The name of the column containing the index used to update the widgets. + loading_widget : optional + A widget to display a loading message while data is being loaded. + + Attributes + ---------- + widgets : dict + A dictionary of widgets being managed. + index_column : str + The name of the column containing the index used to update the widgets. + throttler : ThrottledClick + A throttler for click events. + loading_widget : optional + A widget to display a loading message while data is being loaded. + """ + + def __init__(self, widgets, index_column, loading_widget=None): + self.widgets = widgets + self.index_column = index_column + self.throttler = ThrottledClick() + self.loading_widget = loading_widget + + def __getitem__(self, item): + return self.widgets[item] + + def update(self, feature, **kwargs): + """ + Handle the selection of a marker on the map. + + Parameters + ---------- + feature : dict + A dictionary containing information about the selected feature. + **kwargs : dict + Additional keyword arguments to pass to the widgets update method. + """ + + # Check if we should process the click + if not self.throttler.should_process(): + return + + if self.loading_widget is not None: + self.loading_widget.value = ( + "Loading..." 
# Indicate that data is being loaded
+            )
+
+        # Extract station_id from the selected feature
+        metadata = feature["properties"]
+        index = metadata[self.index_column]
+
+        # update widgets
+        for wgt in self.widgets.values():
+            wgt.update(index, metadata, **kwargs)
+
+        if self.loading_widget is not None:
+            self.loading_widget.value = ""  # Clear the loading message
+
+
+class Widget:
+    """
+    A base class for interactive widgets.
+
+    Parameters
+    ----------
+    output : Output
+        The ipywidget compatible object to display the widget's content.
+
+    Attributes
+    ----------
+    output : Output
+        The ipywidget compatible object to display the widget's content.
+
+    Methods
+    -------
+    update(index, metadata, **kwargs)
+        Update the widget's content based on the given index and metadata.
+    """
+
+    def __init__(self, output):
+        self.output = output
+
+    def update(self, index, *args, **kwargs):
+        raise NotImplementedError
+
+
+def _filter_nan_values(dates, data_values):
+    """
+    Filters out NaN values and their associated dates.
+    """
+    assert len(dates) == len(
+        data_values
+    ), "Dates and data values must be the same length."
+    valid_dates = [date for date, val in zip(dates, data_values) if not np.isnan(val)]
+    valid_data = [val for val in data_values if not np.isnan(val)]
+
+    return valid_dates, valid_data
+
+
+class PlotlyWidget(Widget):
+    """
+    A widget to display timeseries data using Plotly.
+
+    Parameters
+    ----------
+    datasets : dict
+        A dictionary containing the xarray timeseries datasets to be displayed.
+
+    Attributes
+    ----------
+    datasets : dict
+        A dictionary containing the xarray timeseries datasets to be displayed.
+    figure : plotly.graph_objs._figurewidget.FigureWidget
+        The Plotly figure widget.
+    ds_time_str : list
+        A list of strings representing the dates in the timeseries data.
+    start_date_picker : DatePicker
+        The date picker widget for selecting the start date.
+    end_date_picker : DatePicker
+        The date picker widget for selecting the end date.
+    """
+
+    def __init__(self, datasets):
+        self.datasets = datasets
+        self.figure = go.FigureWidget(
+            layout=go.Layout(
+                height=350,
+                margin=dict(l=120),
+                legend=dict(
+                    orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1
+                ),
+                xaxis_title="Date",
+                xaxis_tickformat="%d-%m-%Y",
+                yaxis_title="Discharge [m3/s]",
+            )
+        )
+        ds_time = datasets["obs"]["time"].values.astype("datetime64[D]")
+        self.ds_time_str = [dt.isoformat() for dt in pd.to_datetime(ds_time)]
+
+        self.start_date_picker = DatePicker(description="Start")
+        self.end_date_picker = DatePicker(description="End")
+
+        self.start_date_picker.observe(self._update_plot_dates, names="value")
+        self.end_date_picker.observe(self._update_plot_dates, names="value")
+
+        date_label = Label(
+            "Select the start and end dates to adjust the date axis of the plot"
+        )
+        date_picker_box = HBox([self.start_date_picker, self.end_date_picker])
+
+        layout = Layout(justify_content="center", align_items="center")
+        output = VBox([self.figure, date_label, date_picker_box], layout=layout)
+        super().__init__(output)
+
+    def _update_plot_dates(self, change=None):
+        """
+        Updates the plot with the selected start and end dates.
+        """
+        if self.start_date_picker.value is None or self.end_date_picker.value is None:
+            return
+        start_date = self.start_date_picker.value.strftime("%Y-%m-%d")
+        end_date = self.end_date_picker.value.strftime("%Y-%m-%d")
+        self.figure.update_layout(xaxis_range=[start_date, end_date])
+
+    def _update_data(self, station_id):
+        """
+        Updates the simulation data for the given station ID.
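+
+        Returns
+        -------
+        bool
+            True if the station ID was found in all datasets, False otherwise.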
+ """ + for name, ds in self.datasets.items(): + if station_id in ds["station"].values: + ds_time_series_data = ds.sel(station=station_id).values + valid_dates_ds, valid_data_ds = _filter_nan_values( + self.ds_time_str, ds_time_series_data + ) + self._update_trace(valid_dates_ds, valid_data_ds, name) + else: + print(f"Station ID: {station_id} not found in dataset {name}.") + return False + return True + + def _update_trace(self, x_data, y_data, name): + """ + Updates the plot trace for the given name with the given x and y data. + """ + trace_exists = any([trace.name == name for trace in self.figure.data]) + if trace_exists: + for trace in self.figure.data: + if trace.name == name: + trace.x = x_data + trace.y = y_data + else: + self.figure.add_trace( + go.Scatter(x=x_data, y=y_data, mode="lines", name=name) + ) + + def _update_title(self, metadata): + """ + Updates the plot title following the point metadata. + """ + station_id = metadata["station_id"] + station_name = metadata["StationName"] + updated_title = ( + f"Selected station:
<br>ID: {station_id}, name: {station_name}"
+        )
+        self.figure.update_layout(
+            title={
+                "text": updated_title,
+                "y": 0.9,
+                "x": 0.5,
+                "xanchor": "center",
+                "yanchor": "top",
+                "font": {"color": "black", "size": 16},
+            }
+        )
+
+    def update(self, index, *args, **kwargs):
+        """
+        Updates the overall plot with new data for the given index.
+
+        Parameters
+        ----------
+        index : str
+            The ID of the station to update the data for.
+        metadata : dict
+            A dictionary containing the metadata for the selected station.
+        """
+        return self._update_data(index)
+
+
+class HTMLTableWidget(Widget):
+    """
+    A widget to display a pandas dataframe in HTML format.
+
+    Parameters
+    ----------
+    title : str
+        The title of the table.
+    """
+
+    def __init__(self, title):
+        self.title = title
+        super().__init__(Output())
+
+        # Define the styles for the statistics table
+        self.table_style = """
+            <style>
+                .custom-table {
+                    border-collapse: collapse;
+                    width: 100%;
+                }
+                .custom-table th, .custom-table td {
+                    border: 1px solid #9e9e9e;
+                    padding: 4px;
+                    text-align: center;
+                }
+            </style>
+        """
+        self.stat_title_style = (
+            "style='font-size: 18px; font-weight: bold; text-align: center;'"
+        )
+        # Initialize the stat_table_html and station_table_html with empty tables
+        empty_df = pd.DataFrame()
+        self._display_dataframe_with_scroll(empty_df, title=self.title)
+
+    def _display_dataframe_with_scroll(self, df, title=""):
+        table_html = df.to_html(classes="custom-table")
+        content = f"{self.table_style}

<h2 {self.stat_title_style}>{title}</h2><div style="overflow-x: auto; max-height: 300px;">{table_html}</div>
" # noqa: E501 + with self.output: + clear_output(wait=True) # Clear any previous plots or messages + display(HTML(content)) + + def update(self, index, *args, **kwargs): + """ + Update the table with the dataframe as the given index. + + Parameters + ---------- + index : int + The index of the data to be displayed. + metadata : dict + The metadata associated with the data index. + """ + dataframe = self._extract_dataframe(index) + self._display_dataframe_with_scroll(dataframe, title=self.title) + if dataframe.empty: + return False + return True + + +class DataFrameWidget(Widget): + """ + A widget to display a pandas dataframe with the default pandas display + style. + + Parameters + ---------- + title : str + The title of the table. + """ + + def __init__(self, title): + self.title = title + super().__init__(output=Output(title=self.title)) + + # Initialize the stat_table_html and station_table_html with empty tables + empty_df = pd.DataFrame() + with self.output: + clear_output(wait=True) # Clear any previous plots or messages + display(empty_df) + + def update(self, index, *args, **kwargs): + """ + Update the table with the dataframe as the given index. + + Parameters + ---------- + index : int + The index of the data to be displayed. + metadata : dict + The metadata associated with the data index. + """ + dataframe = self._extract_dataframe(index) + with self.output: + clear_output(wait=True) # Clear any previous plots or messages + display(dataframe) + if dataframe.empty: + return False + return True + + def _extract_dataframe(self, index): + """ + Virtual method to return the object dataframe at the index. + """ + raise NotImplementedError + + +class MetaDataWidget(HTMLTableWidget): + """ + An extension of the HTMLTableWidget class to display a station metadata. + + Parameters + ---------- + dataframe : pd.DataFrame + A pandas dataframe to be displayed in the table. + station_index : str + Column name of the station index. + """ + + def __init__(self, dataframe, station_index): + title = "Station Metadata" + self.dataframe = dataframe + self.station_index = station_index + super().__init__(title) + + def _extract_dataframe(self, station_id): + stations_df = self.dataframe + selected_station_df = stations_df[stations_df[self.station_index] == station_id] + return selected_station_df + + +class StatisticsWidget(HTMLTableWidget): + """ + An extension of the HTMLTableWidget to display statistics at stations. + + Parameters + ---------- + dataframe : pd.DataFrame + A pandas dataframe to be displayed in the table. + station_index : str + Column name of the station index. + """ + + def __init__(self, statistics): + title = "Model Performance Statistics Overview" + self.statistics = statistics + for stat in self.statistics.values(): + assert ( + "station" in stat.dims + ), 'Dimension "station" not found in statistics datasets.' 
# noqa: E501 + super().__init__(title) + + def _extract_dataframe(self, station_id): + """Generate a statistics table for the given station ID.""" + data = [] + + # Check if statistics is None or empty + if not self.statistics: + print("No statistics data provided.") + return pd.DataFrame() # Return an empty dataframe + + # Loop through each simulation and get the statistics for the given station_id + for exp_name, stats in self.statistics.items(): + if station_id in stats["station"].values: + row = [exp_name] + [ + round(stats[var].sel(station=station_id).values.item(), 2) + for var in stats.data_vars + if var not in ["longitude", "latitude"] + ] + data.append(row) + + # Check if data has any items + if not data: + print(f"No statistics data found for station ID: {station_id}.") + return pd.DataFrame() # Return an empty dataframe + + # Convert the data to a DataFrame for display + columns = ["Exp. name"] + list(stats.data_vars.keys()) + statistics_df = pd.DataFrame(data, columns=columns) + + # Round the numerical columns to 2 decimal places + numerical_columns = [col for col in statistics_df.columns if col != "Exp. name"] + statistics_df[numerical_columns] = statistics_df[numerical_columns].round(2) + + return statistics_df diff --git a/hat/networking.py b/hat/networking.py deleted file mode 100644 index a5b04ab..0000000 --- a/hat/networking.py +++ /dev/null @@ -1,44 +0,0 @@ -import os -import platform -import socket - - -def get_host(): - """Local host on Mac and network host on HPC - (note the network address on HPC is not constant)""" - - # get the hostname - hostname = socket.gethostname() - - # get the IP address(es) associated with the hostname - ip_addresses = socket.getaddrinfo( - hostname, None, socket.AF_INET, socket.SOCK_STREAM - ) - - # return first valid address - for ip_address in ip_addresses: - network_host = ip_address[4][0] - return network_host - - -def mac_or_hpc(): - """Is this running on a Mac or the HPC or other?""" - - if platform.system() == "Darwin": - return "mac" - elif platform.system() == "Linux" and os.environ.get("ECPLATFORM"): - return "hpc" - else: - return "other" - - -def host_and_port(host="127.0.0.1", port=8000): - """return network host and port for tiler app to use""" - - computer = mac_or_hpc() - - if computer == "hpc": - host = get_host() - port = 8700 - - return (host, port) diff --git a/hat/observations.py b/hat/observations.py index f895201..467eb70 100644 --- a/hat/observations.py +++ b/hat/observations.py @@ -37,9 +37,6 @@ def read_station_metadata_file( """read hydrological stations from file. will cache as pickle object because .csv file used by the team takes 12 seconds to load""" - print("station file") - print(fpath) - try: if is_csv(fpath): gdf = read_csv_and_cache(fpath) diff --git a/hat/parsers.py b/hat/parsers.py deleted file mode 100644 index 0c50b64..0000000 --- a/hat/parsers.py +++ /dev/null @@ -1,17 +0,0 @@ -import dateutil.parser -import streamlit as st - - -@st.cache_data -def datetime_from_cftime(cftimes): - """parse CFTimeIndex to python datetime, - e.g. 
from a NetCDF file ds.indexes['time']""" - return [dateutil.parser.parse(x.isoformat()) for x in cftimes] - - -def simulation_timeperiod(sim): - # simulation timeperiod - min_time = min(sim.indexes["time"]) - max_time = max(sim.indexes["time"]) - - return (min_time, max_time) diff --git a/hat/plots.py b/hat/plots.py deleted file mode 100644 index 3581409..0000000 --- a/hat/plots.py +++ /dev/null @@ -1,72 +0,0 @@ -from typing import Union - -import numpy as np -import pandas as pd -import plotly.express as px -import streamlit as st -from matplotlib import pyplot as plt - - -# PLOTLY (interactive) -def plotly_timeseries(t, y): - df = pd.DataFrame({"time": t, "discharge": y}) - return px.line(df, x="time", y="discharge", title="Discharge Timeseries") - - -# MATPLOTLIB (not interactive) -def plot_timeseries(t, y, jupyter=False): - fig, ax1 = plt.subplots() - fig.set_size_inches(14, 6) - ax1.plot(t, y, "dodgerblue") - ax1.set_xlabel("time (s)") - ax1.set_ylabel("discharge", color="b") - ax1.tick_params("y", colors="b") - - if jupyter: - return fig - - st.write(fig) - - -def histogram( - arr: Union[np.array, np.ma.MaskedArray], - bins=10, - clip=None, - title="Histogram", - figsize=(6, 4), -): - """plot histogram of a numpy array or masked numpy array""" - - # apply mask (if one exists) - if isinstance(arr, np.ma.MaskedArray): - arr = arr.compressed() - - # return if not numpy - if not isinstance(arr, np.ndarray): - print("histogram() requires a numpy array or masked numpy array") - return - - # remove flat dimensions - arr = arr.squeeze() - - # remove nans - arr = arr[~np.isnan(arr)] - - # histogram range (percentile clip or minmax) - if clip: - histogram_range = ( - round(np.percentile(arr, clip)), - round(np.percentile(arr, 100 - clip)), - ) - else: - histogram_range = (np.min(arr), np.max(arr)) - - # count number of values in each bin - counts, bins = np.histogram(arr, bins=bins, range=histogram_range) - - _ = plt.figure(figsize=figsize) - plt.hist(bins[:-1], bins, weights=counts) - plt.title(title) - - # show plot - plt.show() diff --git a/hat/tools/hydrostats_cli.py b/hat/tools/hydrostats_cli.py index 5019b9c..204f92a 100644 --- a/hat/tools/hydrostats_cli.py +++ b/hat/tools/hydrostats_cli.py @@ -4,6 +4,7 @@ import xarray as xr from hat import hydrostats_functions +from hat.data import find_main_var from hat.exceptions import UserError from hat.filters import filter_timeseries from hat.hydrostats import run_analysis @@ -81,15 +82,19 @@ def hydrostats_cli( # simulations sims_ds = xr.open_dataset(sims) + var = find_main_var(sims_ds, min_dim=2) + sims_da = sims_ds[var] # observations obs_ds = xr.open_dataset(obs) + var = find_main_var(obs_ds, min_dim=2) + obs_da = obs_ds[var] # clean timeseries - sims_ds, obs = filter_timeseries(sims_ds, obs_ds, threshold=obs_threshold) + sims_da, obs_da = filter_timeseries(sims_da, obs_da, threshold=obs_threshold) # calculate statistics - statistics_ds = run_analysis(functions, sims_ds, obs_ds) + statistics_ds = run_analysis(functions, sims_da, obs_da) # save to netcdf statistics_ds.to_netcdf(outpath) diff --git a/hat/version.py b/hat/version.py index 3d18726..906d362 100644 --- a/hat/version.py +++ b/hat/version.py @@ -1 +1 @@ -__version__ = "0.5.0" +__version__ = "0.6.0" diff --git a/notebooks/examples/4_visualisation_interactive.ipynb b/notebooks/examples/4_visualisation_interactive.ipynb new file mode 100644 index 0000000..2194b6d --- /dev/null +++ b/notebooks/examples/4_visualisation_interactive.ipynb @@ -0,0 +1,131 @@ +{ + "cells": [ + { + "cell_type": 
"markdown", + "metadata": {}, + "source": [ + "Import necessart visualisation library and define inputs simulation and additional vector to add" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "config = {\n", + " \"stations\": \"~/git/hat/data/outlets_v4.0_20230726_withEFAS.csv\",\n", + " \"observations\": \"~/git/hat/data/observations/destine_observations.nc\",\n", + " \"simulations\": {\n", + " \"i05j\": \"~/git/hat/data/cama_i05j_stations.nc\",\n", + " \"i05h\": \"~/git/hat/data/cama_i05h_stations.nc\",\n", + " },\n", + " \"statistics\": {\n", + " \"i05j\": \"~/git/hat/data/observations/statistics_i05j.nc\",\n", + " \"i05h\": \"~/git/hat/data/observations/statistics_i05h.nc\",\n", + " },\n", + " \"station_epsg\": 4326,\n", + " \"station_id_column_name\": \"station_id\",\n", + " \"station_filters\":\"\",\n", + " \"station_coordinates\": [\"StationLon\", \"StationLat\"],\n", + " \"obs_var_name\": \"obsdis\",\n", + " \"sims_var_name\": \"dis\",\n", + "}\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Initialise the map object from the configuration." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "ename": "Exception", + "evalue": "Could not open file ~/git/hat/data/outlets_v4.0_20230726_withEFAS.csv", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mCPLE_OpenFailedError\u001b[0m Traceback (most recent call last)", + "File \u001b[0;32mfiona/_shim.pyx:83\u001b[0m, in \u001b[0;36mfiona._shim.gdal_open_vector\u001b[0;34m()\u001b[0m\n", + "File \u001b[0;32mfiona/_err.pyx:291\u001b[0m, in \u001b[0;36mfiona._err.exc_wrap_pointer\u001b[0;34m()\u001b[0m\n", + "\u001b[0;31mCPLE_OpenFailedError\u001b[0m: /home/macw/git/hat/data/outlets_v4.0_20230726_withEFAS.csv: No such file or directory", + "\nDuring handling of the above exception, another exception occurred:\n", + "\u001b[0;31mDriverError\u001b[0m Traceback (most recent call last)", + "File \u001b[0;32m/etc/ecmwf/nfs/dh1_home_a/macw/git/hat/hat/observations.py:42\u001b[0m, in \u001b[0;36mread_station_metadata_file\u001b[0;34m(fpath, coord_names, epsg, filters)\u001b[0m\n\u001b[1;32m 41\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m is_csv(fpath):\n\u001b[0;32m---> 42\u001b[0m gdf \u001b[38;5;241m=\u001b[39m \u001b[43mread_csv_and_cache\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfpath\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 43\u001b[0m gdf \u001b[38;5;241m=\u001b[39m add_geometry_column(gdf, coord_names)\n", + "File \u001b[0;32m/etc/ecmwf/nfs/dh1_home_a/macw/git/hat/hat/data.py:234\u001b[0m, in \u001b[0;36mread_csv_and_cache\u001b[0;34m(fpath)\u001b[0m\n\u001b[1;32m 232\u001b[0m \u001b[38;5;66;03m# otherwise load from user defined filepath (and then cache)\u001b[39;00m\n\u001b[1;32m 233\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 234\u001b[0m gdf \u001b[38;5;241m=\u001b[39m \u001b[43mgpd\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mread_file\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfpath\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 235\u001b[0m gdf\u001b[38;5;241m.\u001b[39mto_pickle(cache_fpath)\n", + "File \u001b[0;32m/usr/local/apps/python3/3.10.10-01/lib/python3.10/site-packages/geopandas/io/file.py:259\u001b[0m, in \u001b[0;36m_read_file\u001b[0;34m(filename, bbox, mask, rows, engine, **kwargs)\u001b[0m\n\u001b[1;32m 258\u001b[0m 
if engine == \"fiona\":\n--> 259     return _read_file_fiona(\n    260         path_or_bytes, from_bytes, bbox=bbox, mask=mask, rows=rows, **kwargs\n    261     )\n    262 elif engine == \"pyogrio\":\n",
+      "File /usr/local/apps/python3/3.10.10-01/lib/python3.10/site-packages/geopandas/io/file.py:303, in _read_file_fiona(path_or_bytes, from_bytes, bbox, mask, rows, where, **kwargs)\n    302 with fiona_env():\n--> 303     with reader(path_or_bytes, **kwargs) as features:\n    304         crs = features.crs_wkt\n",
+      "File /usr/local/apps/python3/3.10.10-01/lib/python3.10/site-packages/fiona/env.py:408, in ensure_env_with_credentials.<locals>.wrapper(*args, **kwargs)\n    407 if local._env:\n--> 408     return f(*args, **kwargs)\n    409 else:\n",
+      "File /usr/local/apps/python3/3.10.10-01/lib/python3.10/site-packages/fiona/__init__.py:264, in open(fp, mode, driver, schema, crs, encoding, layer, vfs, enabled_drivers, crs_wkt, **kwargs)\n    263 if mode in ('a', 'r'):\n--> 264     c = Collection(path, mode, driver=driver, encoding=encoding,\n    265                    layer=layer, enabled_drivers=enabled_drivers, **kwargs)\n    266 elif mode == 'w':\n",
+      "File /usr/local/apps/python3/3.10.10-01/lib/python3.10/site-packages/fiona/collection.py:162, in Collection.__init__(self, path, mode, driver, schema, crs, encoding, layer, vsi, archive, enabled_drivers, crs_wkt, ignore_fields, ignore_geometry, **kwargs)\n    161 self.session = Session()\n--> 162 self.session.start(self, **kwargs)\n    163 elif self.mode in ('a', 'w'):\n",
+      "File fiona/ogrext.pyx:540, in fiona.ogrext.Session.start()\n",
+      "File fiona/_shim.pyx:90, in fiona._shim.gdal_open_vector()\n",
+      "DriverError: /home/macw/git/hat/data/outlets_v4.0_20230726_withEFAS.csv: No such file or directory",
+      "\nDuring handling of the above exception, another exception occurred:\n",
+      "Exception                                 Traceback (most recent call last)",
+      "Cell In[2], line 2\n      1 from hat.interactive.explorers import TimeSeriesExplorer\n----> 2 map = TimeSeriesExplorer(config)\n",
+      "File /etc/ecmwf/nfs/dh1_home_a/macw/git/hat/hat/interactive/explorers.py:211, in TimeSeriesExplorer.__init__(self, config)\n    158 \"\"\"\n    159 Initializes an instance of the Explorer class.\n    160 \n    (...)\n    207 \n    208 \"\"\"\n    209 self.config = config\n--> 211 self.stations_metadata = read_station_metadata_file(\n    212     fpath=config[\"stations\"],\n    213     coord_names=config[\"station_coordinates\"],\n    214     epsg=config[\"station_epsg\"],\n    215     filters=config[\"station_filters\"],\n    216 )\n    218 # Use the external functions to prepare data\n    219 sim_ds = prepare_simulations_data(\n    220     config[\"simulations\"], config[\"sims_var_name\"]\n    221 )\n",
+      "File /etc/ecmwf/nfs/dh1_home_a/macw/git/hat/hat/observations.py:48, in read_station_metadata_file(fpath, coord_names, epsg, filters)\n     46     gdf = gpd.read_file(fpath)\n     47 except Exception:\n---> 48     raise Exception(f\"Could not open file {fpath}\")\n     50 # (optionally) filter the stations, e.g. 'Contintent == Europe'\n     51 if filters is not None:\n",
+      "Exception: Could not open file ~/git/hat/data/outlets_v4.0_20230726_withEFAS.csv"
+     ]
+    }
+   ],
+   "source": [
+    "from hat.interactive.explorers import TimeSeriesExplorer\n",
+    "map = TimeSeriesExplorer(config)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Display the map and use the tool to show a plot of the simulation time series."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "map.mapplot(colorby='kge', sim='i05j')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "hat-venv",
+   "language": "python",
+   "name": "hat-venv"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.10"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
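Note: the committed outputs above capture a real failure, not an illustrative one: `read_station_metadata_file` raises because the configured stations file does not exist on disk. A minimal sketch, not part of this diff, of a guard one could run before building the explorer so the notebook fails fast with a single readable message instead of the nested geopandas/fiona traceback. Only `config["stations"]` is confirmed by the traceback; the helper name and the default key list are hypothetical.

# Hypothetical helper, assuming config values are filesystem paths.
import os

def check_config_paths(config, keys=("stations",)):
    """Raise early, with a clear message, if a configured file is missing."""
    for key in keys:
        path = os.path.expanduser(config[key])  # gpd.read_file does not expand '~'
        if not os.path.isfile(path):
            raise FileNotFoundError(f"config[{key!r}] points to a missing file: {path}")

Calling `check_config_paths(config)` before `TimeSeriesExplorer(config)` would also catch the unexpanded `~` visible in the exception message above.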
diff --git a/notebooks/examples/4_visualisation_interactive_template.ipynb b/notebooks/examples/4_visualisation_interactive_template.ipynb
new file mode 100644
index 0000000..544522f
--- /dev/null
+++ b/notebooks/examples/4_visualisation_interactive_template.ipynb
@@ -0,0 +1,119 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "1. Import the visualisation library and define the inputs (stations, observations, simulations, statistics, and the config, which may differ depending on the simulation)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from hat.visualisation import InteractiveMap\n",
+    "\n",
+    "stations = '~/destinE/outlets_v4.0_20230726_withEFAS.csv'\n",
+    "observations = '~/destinE/destine_observation.nc'\n",
+    "simulations = {\n",
+    "    \"i05j\": \"~/destinE/cama_i05j_stations.nc\",\n",
+    "    \"i05h\": \"~/destinE/cama_i05h_stations.nc\"\n",
+    "}\n",
+    "# xarray input\n",
+    "\n",
+    "statistics = {\n",
+    "    \"i05j\": \"~/destinE/statistics_i05j.nc\",\n",
+    "    \"i05h\": \"~/destinE/statistics_i05h.nc\"\n",
+    "}\n",
+    "# xarray input\n",
+    "\n",
+    "config = {\n",
+    "    \"station_epsg\": 4326,\n",
+    "    \"station_id_column_name\": \"station_id\",\n",
+    "    \"station_filters\": \"\",\n",
+    "    \"station_coordinates\": [\"StationLon\", \"StationLat\"],\n",
+    "    \"obs_var_name\": \"obsdis\"\n",
+    "}"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "2. Initialise the map object and process all the input data"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Found 1614 common stations\n"
+     ]
+    }
+   ],
+   "source": [
+    "map = InteractiveMap(config, stations, observations, simulations, stats=statistics)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "3. Display the map interface for interactive viewing"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "44784a504c27442990df5eb6a5f26a72",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "VBox(children=(Label(value='Interactive Map Visualisation for Hydrological Model Performance', layout=Layout(j…"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "map.mapplot(colorby='kge', sim='i05h', range=[-1, 1])"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "hat-dev",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.12"
+  },
+  "orig_nbformat": 4
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
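The template colours stations by `colorby='kge'` and clamps the colour range to [-1, 1], which suits the Kling-Gupta efficiency: its ideal value is 1 and it has no lower bound, so an open range would let a few poor stations wash out the palette. To see which other metrics a statistics file offers for `colorby`, a quick inspection sketch; the path and file layout are the notebook's own example values, not guaranteed by this diff.

# Inspection sketch: list candidate metric names to pass as `colorby`.
import os
import xarray as xr

stats = xr.open_dataset(os.path.expanduser("~/destinE/statistics_i05j.nc"))
print(list(stats.data_vars))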
diff --git a/tests/test_hydrostats_decorators.py b/tests/test_hydrostats_decorators.py
index 3a53a38..c3839aa 100644
--- a/tests/test_hydrostats_decorators.py
+++ b/tests/test_hydrostats_decorators.py
@@ -63,10 +63,8 @@ def test_filter_nan():
     assert np.allclose(decorated(nan3, nan3), np.array([2, 4]))
 
     # all nans
-    with pytest.raises(ValueError):
-        decorated(arr, nans)
-    with pytest.raises(ValueError):
-        decorated(nans, arr)
+    assert decorated(arr, nans) is None
+    assert decorated(nans, arr) is None
 
 
 # def test_handle_divide_by_zero_error():
@@ -124,8 +122,7 @@ def test_hydrostat():
 
 #     # all zero division
 #     with pytest.raises(ZeroDivisionError):
-#         decorated_divide(ones, zeros)
+#         print(decorated_divide(ones, zeros))
 
     # all nans
-    with pytest.raises(ValueError):
-        decorated_divide(nans, nans)
+    assert decorated_divide(nans, nans) is None
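The test changes above encode a behavioural change, not a relaxation: on all-NaN input the hydrostats decorators now return None instead of raising ValueError, so callers must handle a None result. A sketch of the contract the updated tests imply; the real decorator lives in the hat package and may differ in detail.

# Sketch of the implied all-NaN behaviour, assuming paired 1-D numpy arrays.
import functools
import numpy as np

def filter_nan(func):
    @functools.wraps(func)
    def wrapper(sims, obs):
        # Keep only positions where both series have valid values.
        keep = ~np.isnan(sims) & ~np.isnan(obs)
        if not keep.any():
            return None  # previously: raise ValueError
        return func(sims[keep], obs[keep])
    return wrapper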
["2007-07-13", "2007-01-14"], dtype="datetime64[ns]" + ), + "station": [0, 1, 2], + }, + ), + "sim2": xr.DataArray( + [[4, 5, 6], [4, 5, 6]], + coords={ + "time": np.array( + ["2007-07-13", "2007-01-14"], dtype="datetime64[ns]" + ), + "station": [0, 1, 2], + }, + ), + } + widget = wd.PlotlyWidget( + datasets=datasets, + ) + assert widget.update(1) is True # if id found + assert widget.update(3) is False # if id not found + + +class TestMetaDataWidget: + def test_update(self): + df = pd.DataFrame( + {"col1": [1, 2, 3], "col2": ["a", "b", "c"], "col3": [0.1, 0.2, 0.3]} + ) + widget = wd.MetaDataWidget(df, "col2") + assert widget.update("a") is True + assert widget.update("f") is False + + +class TestStatisticsWidget: + def test_update(self): + datasets = { + "sim1": xr.Dataset( + {"data": [0, 1, 2]}, + coords={ + "station": [0, 1, 2], + }, + ), + "sim2": xr.Dataset( + {"data": [4, 5, 6]}, + coords={ + "station": [0, 1, 2], + }, + ), + } + widget = wd.StatisticsWidget(datasets) + assert widget.update(0) is True + assert widget.update(4) is False + + def test_fail_creation(self): + datasets = { + "sim1": xr.Dataset( + {"data": [0, 1, 2]}, + coords={ + "id": [0, 1, 2], + }, + ), + } + with pytest.raises(AssertionError): + wd.StatisticsWidget(datasets) + + +class TestPyleafletColormap: + config = { + "station_id_column_name": "station", + } + stats = xr.DataArray( + [0, 1, 2], + coords={ + "station": [0, 1, 2], + }, + ) + + def test_default_creation(self): + lf.PyleafletColormap() + + def test_creation_with_stats(self): + lf.PyleafletColormap( + self.config, self.stats, colormap_style="plasma", range=[1, 2] + ) + + def test_creation_wrong_colormap(self): + with pytest.raises(KeyError): + lf.PyleafletColormap( + self.config, self.stats, colormap_style="awdawd", range=[1, 2] + ) + + def test_fail_creation(self): + config = {} + with pytest.raises(AssertionError): + lf.PyleafletColormap(config, self.stats) + + def test_default_style(self): + colormap = lf.PyleafletColormap(self.config) + style_fct = colormap.style_callback() + style_fct(feature={}) + + def test_stats_style(self): + feature = { + "properties": { + "station": 2, + } + } + colormap = lf.PyleafletColormap(self.config, self.stats) + style_fct = colormap.style_callback() + style_fct(feature) + + def test_stats_style_fail(self): + feature = { + "properties": { + "station": 4, + } + } + colormap = lf.PyleafletColormap(self.config, self.stats, empty_color="black") + style_fct = colormap.style_callback() + style = style_fct(feature) + assert style["fillColor"] == "black" + + def test_stats_style_default(self): + colormap = lf.PyleafletColormap(default_color="blue") + style_fct = colormap.style_callback() + style = style_fct({}) + assert style["fillColor"] == "blue" + + def test_legend(self): + colormap = lf.PyleafletColormap(self.config, self.stats) + colormap.legend() + + +class TestLeafletMap: + def test_creation(self): + lf.LeafletMap() + + def test_output(self): + map = lf.LeafletMap() + map.output() + + def test_add_geolayer(self): + map = lf.LeafletMap() + widgets = {} + colormap = lf.PyleafletColormap() + gdf = gpd.GeoDataFrame(geometry=[Point(0, 0)]) + map.add_geolayer(gdf, colormap, widgets)