post: streamline the processes
split into three separate functions
brey committed Sep 28, 2023
1 parent 6e5816f commit e65b211
Showing 3 changed files with 76 additions and 34 deletions.
37 changes: 36 additions & 1 deletion pyposeidon/utils/obs.py
@@ -9,6 +9,36 @@
import numpy as np
import xarray as xr
from datetime import datetime
import itertools
from typing import Iterable
from typing import Iterator


def grouper(
    iterable: Iterable[_T],
    n: int,
    *,
    incomplete: str = "fill",
    fillvalue: Union[_U, None] = None,
) -> Iterator[Tuple[Union[_T, _U], ...]]:
    """Collect data into non-overlapping fixed-length chunks or blocks"""
    # grouper('ABCDEFG', 3, fillvalue='x') --> ABC DEF Gxx
    # grouper('ABCDEFG', 3, incomplete='strict') --> ABC DEF ValueError
    # grouper('ABCDEFG', 3, incomplete='ignore') --> ABC DEF
    args = [iter(iterable)] * n
    if incomplete == "fill":
        return itertools.zip_longest(*args, fillvalue=fillvalue)
    if incomplete == "strict":
        return zip(*args, strict=True)  # type: ignore[call-overload]
    if incomplete == "ignore":
        return zip(*args)
    else:
        raise ValueError("Expected fill, strict, or ignore")


def merge_datasets(datasets: List[xr.Dataset], size: int = 5) -> List[xr.Dataset]:
    datasets = [xr.merge(g for g in group if g) for group in grouper(datasets, size)]
    return datasets


def get_bogus(temp, idfs, ntime):
@@ -52,7 +82,12 @@ def get_obs_data(stations: str | gp.GeoDataFrame, start_time=None, end_time=None
    xdfs = []
    for idfs in sn.ioc_code:
        xdfs.append(get_bogus(temp, idfs, ntime))
    # merge

    # in order to keep memory consumption low, let's group the datasets
    # and merge them in batches
    while len(xdfs) > 5:
        xdfs = merge_datasets(xdfs)
    # Do the final merging
    sns = xr.merge(xdfs)

    # set correct dims for metadata
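For context, here is a minimal sketch of the batched-merge pattern these additions introduce: `grouper` splits the list of per-station datasets into fixed-size groups, `merge_datasets` merges each group, and `get_obs_data` repeats the pass until few enough datasets remain for one final `xr.merge`. The toy datasets and variable names below are illustrative only; the snippet inlines simplified versions of the two helpers (without the type annotations, whose `_T`/`_U` TypeVars are presumably defined earlier in `obs.py`, outside the shown hunk) rather than importing them from `pyposeidon.utils.obs`.

```python
import itertools

import numpy as np
import xarray as xr


def grouper(iterable, n, *, fillvalue=None):
    # non-overlapping chunks of length n, padding the last group with fillvalue
    args = [iter(iterable)] * n
    return itertools.zip_longest(*args, fillvalue=fillvalue)


def merge_datasets(datasets, size=5):
    # merge every `size` datasets into one, skipping the None padding
    return [xr.merge(g for g in group if g) for group in grouper(datasets, size)]


# 23 toy single-variable datasets sharing a time axis (illustrative stand-ins
# for the per-station datasets built by get_bogus)
xdfs = [
    xr.Dataset({f"elev_{i}": ("time", np.random.rand(10))}, coords={"time": np.arange(10)})
    for i in range(23)
]

# merge in batches so each xr.merge call only sees a handful of datasets
while len(xdfs) > 5:
    xdfs = merge_datasets(xdfs)  # first pass: 23 datasets -> 5 partial merges

sns = xr.merge(xdfs)  # final merge of the remaining partial results
print(len(sns.data_vars))  # 23
```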
72 changes: 40 additions & 32 deletions pyposeidon/utils/post.py
@@ -50,10 +50,10 @@ def get_encoding(ename):
    }


def save_leads(stations, st, start_date, dt, freq, rpath="./skill/"):
def save_leads(stations, st, start_date, dt, leads, rpath="./skill/"):
    ## save results in lead chunks

    for l in range(freq):
    for l in range(leads):
        from_date = start_date + pd.to_timedelta("{}H".format(l * dt)) + pd.to_timedelta("1S")
        to_date = start_date + pd.to_timedelta("{}H".format((l + 1) * dt))
        h = st.sel(time=slice(from_date, to_date), drop=True)
@@ -106,15 +106,16 @@ def gather_obs_data(stations, start_time, end_time, rpath="./thalassa/obs/"):
    return


def compute_stats(st, rpath="./thalassa/obs/"):
def to_stats(st, rpath="./thalassa/", opath="./thalassa/obs/"):
    logger.info("compute general statistics for station points\n")

    ids = st.id.values

    sts = []
    for inode in tqdm(st.id.values):
        isim = st.where(st.id == inode).dropna(dim="id")
        sim = isim.elev.to_dataframe().droplevel(1)
    for inode in tqdm(ids):
        sim = st.sel(id=inode).elev_sim.to_dataframe().drop("id", axis=1)

        filename = os.path.join(rpath, f"{inode}.parquet")
        filename = os.path.join(opath, f"{inode}.parquet")
        obs = pd.read_parquet(filename, engine="fastparquet")
        obs_ = obs.dropna(axis=1, how="all") # drop all nan columns

@@ -124,35 +125,29 @@ def compute_stats(st, rpath="./thalassa/obs/"):
        cols = [x for x in obs_.columns if x in valid_sensors]
        obs = obs_[[cols[0]]] # just choose one for now

        if obs.dropna().empty:
            logger.warning(f"Observation data not available for {inode} station")

        stable = get_stats(sim, obs) # Do general statistics

        sts.append(stable)

    return sts


def save_stats(sts, stations, **kwargs):
    rpath = kwargs.get("rpath", "./thalassa/")

    logger.info("save stats\n")

    ids = stations.id.values

    stats = pd.DataFrame(sts)
    stats.index.name = "id"
    stats = stats.to_xarray().assign_coords({"id": ids}) # stats

    rpath = kwargs.get("rpath", "./thalassa/")

    output_path = os.path.join(rpath, "stats.nc")

    stats.to_netcdf(output_path)
    logger.info(f"..done with stats file\n")


def to_thalassa(folder, freq=None, **kwargs):
def to_thalassa(folder, **kwargs):
    # Retrieve data
    tag = kwargs.get("tag", "schism")
    leads = kwargs.get("leads", None)
    rpath = kwargs.get("rpath", "./thalassa/")
    gglobal = kwargs.get("gglobal", False)
    to2d = kwargs.get("to2d", None)
@@ -195,6 +190,10 @@ def to_thalassa(folder, freq=None, **kwargs):
    rvars_ = rvars + [x_var, y_var, tes_var]
    out = out[rvars_]

    # Add max elevation variable
    out = out.assign(max_elev=out[ename].max("time"))
    rvars = rvars + ["max_elev"]

    # set encoding
    encoding = {}
    for var in rvars:
@@ -216,15 +215,16 @@

    logger.info("save time series depending on lead time\n")

    skill_path = os.path.join(rpath, "skill")
    total_hours = pd.to_timedelta(b.time_frame) / pd.Timedelta(hours=1)
    dt = total_hours / freq
    if leads:
        skill_path = os.path.join(rpath, "skill")
        total_hours = pd.to_timedelta(b.time_frame) / pd.Timedelta(hours=1)
        dt = total_hours / leads

    if dt % 1 == 0.0:
        save_leads(stations, st, b.start_date, dt, freq, rpath=skill_path)
        if dt % 1 == 0.0:
            save_leads(stations, st, b.start_date, dt, leads, rpath=skill_path)

    else:
        logger.warning("freq not correct, aborting\n")
        else:
            logger.warning("leads value not correct, aborting\n")

    # save sim data

@@ -246,6 +246,20 @@
    vdata.to_netcdf(output_path)
    logger.info(f"..done with {filename} file\n")


def to_obs(folder, **kwargs):
    tag = kwargs.get("tag", "schism")
    rpath = kwargs.get("rpath", "./thalassa/")

    json_file = os.path.join(folder, "{}_model.json".format(tag))
    b = pyposeidon.model.read(json_file)

    stations = gp.GeoDataFrame.from_file(b.obs)

    # assign unique id
    if "id" not in stations.columns:
        stations["id"] = [f"IOC-{x}" for x in stations.ioc_code]

    # get observations last timestamp
    obs_files_path = os.path.join(rpath, "obs/")
    if not os.path.exists(obs_files_path):
@@ -261,10 +275,4 @@ def to_thalassa(folder, freq=None, **kwargs):

    gather_obs_data(stations, start_date, b.end_date, rpath=obs_files_path)

    # compute stats
    logger.info("compute statistics")
    sts = compute_stats(st, rpath=obs_files_path)

    save_stats(sts, stations, **kwargs)

    logger.info(f"post processing complete for folder {folder}\n")
    return
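Taken together, post-processing is now exposed through three entry points (`to_thalassa`, `to_obs`, and `to_stats`) rather than a single routine. A hypothetical driver for the new layout might look like the sketch below; the run folder and paths are illustrative, the keyword names follow the kwargs read in the diff above, and loading the station time-series dataset `st` needed by `to_stats` is outside this commit, so that call is left commented.

```python
# Hypothetical driver for the split post-processing API (paths are illustrative).
from pyposeidon.utils import post

folder = "./runs/20230928.00/"   # a finished model run directory (illustrative)
rpath = "./thalassa/"

# 1. Convert simulation output for thalassa; when the total run hours divide
#    evenly by `leads`, the time series is also saved in lead-time chunks
#    under rpath/skill (see save_leads above).
post.to_thalassa(folder, rpath=rpath, leads=4)

# 2. Fetch/refresh the per-station observation parquet files under rpath/obs.
post.to_obs(folder, rpath=rpath)

# 3. Compute per-station statistics against those observations.
#    `st` is the station time-series dataset; how it is loaded is not shown in this diff.
# sts = post.to_stats(st, rpath=rpath, opath=rpath + "obs/")
```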
1 change: 0 additions & 1 deletion pyposeidon/utils/statistics.py
@@ -10,7 +10,6 @@

def get_stats(sim_, obs_):
    if obs_.dropna().empty:
        logger.warning("Observation data not available for this station")
        stats = pd.Series(
            index={
                "Mean Absolute Error",
