diff --git a/pyposeidon/utils/obs.py b/pyposeidon/utils/obs.py
index d454f668..2f9e27f9 100644
--- a/pyposeidon/utils/obs.py
+++ b/pyposeidon/utils/obs.py
@@ -9,6 +9,43 @@
 import numpy as np
 import xarray as xr
 from datetime import datetime
+import itertools
+from typing import Iterable
+from typing import Iterator
+from typing import List
+from typing import Tuple
+from typing import TypeVar
+from typing import Union
+
+_T = TypeVar("_T")
+_U = TypeVar("_U")
+
+
+def grouper(
+    iterable: Iterable[_T],
+    n: int,
+    *,
+    incomplete: str = "fill",
+    fillvalue: Union[_U, None] = None,
+) -> Iterator[Tuple[Union[_T, _U], ...]]:
+    """Collect data into non-overlapping fixed-length chunks or blocks"""
+    # grouper('ABCDEFG', 3, fillvalue='x') --> ABC DEF Gxx
+    # grouper('ABCDEFG', 3, incomplete='strict') --> ABC DEF ValueError
+    # grouper('ABCDEFG', 3, incomplete='ignore') --> ABC DEF
+    args = [iter(iterable)] * n
+    if incomplete == "fill":
+        return itertools.zip_longest(*args, fillvalue=fillvalue)
+    if incomplete == "strict":
+        return zip(*args, strict=True)  # type: ignore[call-overload]
+    if incomplete == "ignore":
+        return zip(*args)
+    else:
+        raise ValueError("Expected fill, strict, or ignore")
+
+
+def merge_datasets(datasets: List[xr.Dataset], size: int = 5) -> List[xr.Dataset]:
+    datasets = [xr.merge(g for g in group if g) for group in grouper(datasets, size)]
+    return datasets
 
 
 def get_bogus(temp, idfs, ntime):
@@ -52,7 +89,12 @@ def get_obs_data(stations: str | gp.GeoDataFrame, start_time=None, end_time=None
     xdfs = []
     for idfs in sn.ioc_code:
         xdfs.append(get_bogus(temp, idfs, ntime))
-    # merge
+
+    # in order to keep memory consumption low, let's group the datasets
+    # and merge them in batches
+    while len(xdfs) > 5:
+        xdfs = merge_datasets(xdfs)
+    # Do the final merging
     sns = xr.merge(xdfs)
 
     # set correct dims for metadata
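
Note (illustration, not part of the patch): grouper() above is the standard
itertools "grouper" recipe, and merge_datasets() uses it to merge in batches.
A minimal sketch of the batching, assuming the two helpers from this patch are
importable from pyposeidon.utils.obs; the toy datasets and the count of 23 are
made up:

    import xarray as xr

    from pyposeidon.utils.obs import grouper, merge_datasets

    # 23 single-variable datasets stand in for the per-station bogus datasets.
    datasets = [xr.Dataset({f"station_{i}": ("time", [float(i)])}) for i in range(23)]

    # grouper() pads the last chunk with the fillvalue (None):
    # 23 items -> chunks of 5, 5, 5, 5, 3 plus two None.
    assert len(list(grouper(datasets, 5))) == 5

    # Each pass merges groups of 5 (the "if g" filter drops the None padding),
    # so memory peaks at one small group at a time instead of all 23 at once.
    while len(datasets) > 5:
        datasets = merge_datasets(datasets)  # 23 -> 5

    merged = xr.merge(datasets)  # the final merge, as in get_obs_data()
    assert len(merged.data_vars) == 23
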
logger.info("save stats\n") - ids = stations.id.values - stats = pd.DataFrame(sts) stats.index.name = "id" stats = stats.to_xarray().assign_coords({"id": ids}) # stats - rpath = kwargs.get("rpath", "./thalassa/") - output_path = os.path.join(rpath, "stats.nc") stats.to_netcdf(output_path) logger.info(f"..done with stats file\n") -def to_thalassa(folder, freq=None, **kwargs): +def to_thalassa(folder, **kwargs): # Retrieve data tag = kwargs.get("tag", "schism") + leads = kwargs.get("leads", None) rpath = kwargs.get("rpath", "./thalassa/") gglobal = kwargs.get("gglobal", False) to2d = kwargs.get("to2d", None) @@ -195,6 +190,10 @@ def to_thalassa(folder, freq=None, **kwargs): rvars_ = rvars + [x_var, y_var, tes_var] out = out[rvars_] + # Add max elevation variable + out = out.assign(max_elev=out[ename].max("time")) + rvars = rvars + ["max_elev"] + # set enconding encoding = {} for var in rvars: @@ -216,15 +215,16 @@ def to_thalassa(folder, freq=None, **kwargs): logger.info("save time series depending on lead time\n") - skill_path = os.path.join(rpath, "skill") - total_hours = pd.to_timedelta(b.time_frame) / pd.Timedelta(hours=1) - dt = total_hours / freq + if leads: + skill_path = os.path.join(rpath, "skill") + total_hours = pd.to_timedelta(b.time_frame) / pd.Timedelta(hours=1) + dt = total_hours / leads - if dt % 1 == 0.0: - save_leads(stations, st, b.start_date, dt, freq, rpath=skill_path) + if dt % 1 == 0.0: + save_leads(stations, st, b.start_date, dt, leads, rpath=skill_path) - else: - logger.warning("freq not correct, aborting\n") + else: + logger.warning("leads value not correct, aborting\n") # save sim data @@ -246,6 +246,20 @@ def to_thalassa(folder, freq=None, **kwargs): vdata.to_netcdf(output_path) logger.info(f"..done with {filename} file\n") + +def to_obs(folder, **kwargs): + tag = kwargs.get("tag", "schism") + rpath = kwargs.get("rpath", "./thalassa/") + + json_file = os.path.join(folder, "{}_model.json".format(tag)) + b = pyposeidon.model.read(json_file) + + stations = gp.GeoDataFrame.from_file(b.obs) + + # assign unique id + if "id" not in stations.columns: + stations["id"] = [f"IOC-{x}" for x in stations.ioc_code] + # get observations last timestamp obs_files_path = os.path.join(rpath, "obs/") if not os.path.exists(obs_files_path): @@ -261,10 +275,4 @@ def to_thalassa(folder, freq=None, **kwargs): gather_obs_data(stations, start_date, b.end_date, rpath=obs_files_path) - # compute stats - logger.info("compute statistics") - sts = compute_stats(st, rpath=obs_files_path) - - save_stats(sts, stations, **kwargs) - - logger.info(f"post processing complete for folder {folder}\n") + return diff --git a/pyposeidon/utils/statistics.py b/pyposeidon/utils/statistics.py index 2d9afe43..8d4a0146 100755 --- a/pyposeidon/utils/statistics.py +++ b/pyposeidon/utils/statistics.py @@ -10,7 +10,6 @@ def get_stats(sim_, obs_): if obs_.dropna().empty: - logger.warning("Observation data not available for this station") stats = pd.Series( index={ "Mean Absolute Error",