diff --git a/pyposeidon/boundary.py b/pyposeidon/boundary.py
index 4487767a..2b337568 100644
--- a/pyposeidon/boundary.py
+++ b/pyposeidon/boundary.py
@@ -302,13 +302,17 @@ def tag(geometry, coasts, cbuffer, blevels):
 
     ww = gp.GeoDataFrame(geometry=cs)
 
-    try:
-        gw = gp.GeoDataFrame(
-            geometry=list(ww.buffer(0).unary_union)
-        )  # merge the polygons that are split (around -180/180)
-    except:
+    if ww.empty:
         gw = gp.GeoDataFrame(geometry=list(ww.values))
+    else:
+        try:
+            gw = gp.GeoDataFrame(
+                geometry=list(ww.buffer(0).unary_union)
+            )  # merge the polygons that are split (around -180/180)
+        except:
+            gw = gp.GeoDataFrame(geometry=list(ww.values))
+
     if wc.geom_type.all() != "Polygon":
         gw = gp.GeoDataFrame(geometry=gw.boundary.values)
 
diff --git a/pyposeidon/dem.py b/pyposeidon/dem.py
index bd712cfb..fcd47595 100644
--- a/pyposeidon/dem.py
+++ b/pyposeidon/dem.py
@@ -87,7 +87,7 @@ def __init__(self, dem_source: str, **kwargs):
             self.adjust(coastline, **kwargs)
 
     def adjust(self, coastline, **kwargs):
-        self.Dataset, check = fix(self.Dataset, coastline, **kwargs)
+        self.Dataset, check, flag = fix(self.Dataset, coastline, **kwargs)
 
         if not check:
             logger.warning("Adjusting dem failed, keeping original values\n")
diff --git a/pyposeidon/schism.py b/pyposeidon/schism.py
index c3418608..427da1a3 100644
--- a/pyposeidon/schism.py
+++ b/pyposeidon/schism.py
@@ -566,7 +566,11 @@ def output(self, **kwargs):
             logger.info("Keeping bathymetry from hgrid.gr3 ..\n")
             copyfile(self.mesh_file, os.path.join(path, "hgrid.gr3"))  # copy original grid file
 
-        copyfile(os.path.join(path, "hgrid.gr3"), os.path.join(path, "hgrid.ll"))
+        src = os.path.join(path, "hgrid.gr3")
+        dst = os.path.join(path, "hgrid.ll")
+        if os.path.lexists(dst):
+            os.remove(dst)
+        os.symlink(src, dst)
 
         # manning file
         manfile = os.path.join(path, "manning.gr3")
@@ -677,6 +681,11 @@ def run(self, **kwargs):
 
         proc = tools.execute_schism_mpirun_script(cwd=calc_dir)
 
+        if proc.returncode == 0:
+            # ---------------------------------------------------------------------
+            logger.info("model finished successfully\n")
+            # ---------------------------------------------------------------------
+
     def save(self, **kwargs):
 
         path = get_value(self, kwargs, "rpath", "./schism/")
@@ -705,7 +714,13 @@ def save(self, **kwargs):
             dic.update({"meteo": [x.attrs for x in meteo.Dataset]})
 
         coastlines = self.__dict__.get("coastlines", None)
-        dic.update({"coastlines": coastlines})
+        if coastlines is not None:
+            # Save the path to the serialized coastlines - #130
+            coastlines_database = os.path.join(path, "coastlines.json")
+            coastlines.to_file(coastlines_database)
+            dic.update({"coastlines": coastlines_database})
+        else:
+            dic.update({"coastlines": None})
 
         dic["version"] = pyposeidon.__version__
@@ -1297,10 +1312,16 @@ def results(self, **kwargs):
         path = get_value(self, kwargs, "rpath", "./schism/")
 
         logger.info("get combined 2D netcdf files \n")
+        # check for new IO output
         hfiles = glob.glob(os.path.join(path, "outputs/out2d_*.nc"))
         hfiles.sort()
 
+        # check for old IO output
+        ofiles = glob.glob(os.path.join(path, "outputs/schout_*_*.nc"))
+
+        if not (ofiles or hfiles):
+            logger.warning("no output netcdf files, moving on")
-        if hfiles:
+        elif hfiles:
             x2d = xr.open_mfdataset(hfiles, data_vars="minimal")
 
             # set timestamp
@@ -1335,7 +1356,7 @@ def results(self, **kwargs):
             # save 2D variables to file
             x2d.to_netcdf(os.path.join(path, "outputs/schout_1.nc"))
 
-        else:
+        elif ofiles:
             if len(self.misc) == 0:
                 logger.info("retrieving index references ... \n")
                 self.global2local(**kwargs)
@@ -1704,6 +1725,9 @@ def results(self, **kwargs):
 
                 xc.to_netcdf(os.path.join(path, f"outputs/schout_{val}.nc"))
 
+        else:
+            raise Exception("This should never happen")
+
         logger.info("done with output netCDF files \n")
 
     def set_obs(self, **kwargs):
@@ -1733,6 +1757,9 @@ def set_obs(self, **kwargs):
         ##### normalize to be used inside pyposeidon
         tgn = normalize_column_names(tg.copy())
 
+        ##### make sure lat/lon are floats
+        tgn = tgn.astype({"latitude": float, "longitude": float})
+
         coastal_monitoring = get_value(self, kwargs, "coastal_monitoring", False)
         flags = get_value(self, kwargs, "station_flags", [1] + [0] * 8)
@@ -1833,7 +1860,7 @@ def set_obs(self, **kwargs):
             self.params["SCHOUT"]["nspool_sta"] = nspool_sta
             self.params.write(os.path.join(path, "param.nml"), force=True)
 
-        self.stations = stations
+        self.stations_mesh_id = stations
 
         logger.info("write out stations.in file \n")
@@ -1853,25 +1880,25 @@ def get_station_sim_data(self, **kwargs):
         try:
             # get the station flags
             flags = pd.read_csv(os.path.join(path, "station.in"), header=None, nrows=1, delim_whitespace=True).T
-            flags.columns = ["flag"]
-            flags["variable"] = [
-                "elev",
-                "air_pressure",
-                "windx",
-                "windy",
-                "T",
-                "S",
-                "u",
-                "v",
-                "w",
-            ]
-
-            vals = flags[flags.values == 1]  # get the active ones
-        except OSError as e:
-            if e.errno == errno.EEXIST:
-                logger.error("no station.in file present")
+        except FileNotFoundError:
+            logger.error("no station.in file present")
             return
+
+        flags.columns = ["flag"]
+        flags["variable"] = [
+            "elev",
+            "air_pressure",
+            "windx",
+            "windy",
+            "T",
+            "S",
+            "u",
+            "v",
+            "w",
+        ]
+
+        vals = flags[flags.values == 1]  # get the active ones
+
         dstamp = kwargs.get("dstamp", self.rdate)
 
         dfs = []
diff --git a/pyposeidon/utils/cast.py b/pyposeidon/utils/cast.py
index b0af1522..f96350d6 100644
--- a/pyposeidon/utils/cast.py
+++ b/pyposeidon/utils/cast.py
@@ -285,6 +285,10 @@ def run(self, **kwargs):
         self.rdate = self.model.rdate
         ppath = self.ppath
 
+        # ppath = pathlib.Path(ppath).resolve()
+        # ppath = str(ppath)
+        ppath = os.path.realpath(ppath)
+
         # control
         if not isinstance(self.rdate, pd.Timestamp):
             self.rdate = pd.to_datetime(self.rdate)
@@ -296,6 +300,10 @@ def run(self, **kwargs):
         # create the new folder/run path
         rpath = self.cpath
 
+        # rpath = pathlib.Path(rpath).resolve()
+        # rpath = str(rpath)
+        rpath = os.path.realpath(rpath)
+
         if not os.path.exists(rpath):
             os.makedirs(rpath)
@@ -309,11 +317,16 @@ def run(self, **kwargs):
         info = data.to_dict(orient="records")[0]
 
         try:
-            args = set(kwargs.keys()).intersection(info.keys())  # modify dic with kwargs
+            args = info.keys() & kwargs.keys()  # modify dic with kwargs
             for attr in list(args):
-                info[attr] = kwargs[attr]
-        except:
-            pass
+                if isinstance(info[attr], dict):
+                    info[attr].update(kwargs[attr])
+                else:
+                    info[attr] = kwargs[attr]
+                setattr(self, attr, info[attr])
+        except Exception as e:
+            logger.exception("problem with kwargs integration\n")
+            raise e
 
         # add optional additional kwargs
         for attr in kwargs.keys():
@@ -333,30 +346,14 @@ def run(self, **kwargs):
 
         m = pm.set(**info)
 
-        # Mesh
-        gfile = glob.glob(os.path.join(ppath, "hgrid.gr3"))
-        if gfile:
-            info["mesh_file"] = gfile[0]
-            self.mesh_file = gfile[0]
-            info["mesh_generator"] = None
-            self.mesh_generator = None
-
-        m.mesh = pmesh.set(type="tri2d", **info)
-
-        # get lat/lon from file
-        if hasattr(self, "mesh_file"):
-            info.update({"lon_min": m.mesh.Dataset.SCHISM_hgrid_node_x.values.min()})
-            info.update({"lon_max": m.mesh.Dataset.SCHISM_hgrid_node_x.values.max()})
- info.update({"lat_min": m.mesh.Dataset.SCHISM_hgrid_node_y.values.min()}) - info.update({"lat_max": m.mesh.Dataset.SCHISM_hgrid_node_y.values.max()}) - # copy/link necessary files logger.debug("Copy necessary + station files") copy_files(rpath=rpath, ppath=ppath, filenames=self.files + self.station_files) - logger.debug("Copy model files") if copy: + logger.debug("Copy model files") copy_files(rpath=rpath, ppath=ppath, filenames=self.model_files) else: + logger.debug("Symlink model files") symlink_files(rpath=rpath, ppath=ppath, filenames=self.model_files) logger.debug(".. done") @@ -391,18 +388,13 @@ def run(self, **kwargs): else: logger.info("Symlinking`: %s -> %s", inresfile, outresfile) try: - os.symlink( - pathlib.Path(os.path.join(ppath, inresfile)).resolve(strict=True), os.path.join(rpath, outresfile) - ) + os.symlink(inresfile, outresfile) except OSError as e: if e.errno == errno.EEXIST: logger.warning("Restart link present\n") logger.warning("overwriting\n") - os.remove(os.path.join(rpath, outresfile)) - os.symlink( - pathlib.Path(os.path.join(ppath, inresfile)).resolve(strict=True), - os.path.join(rpath, outresfile), - ) + os.remove(outresfile) + os.symlink(inresfile, outresfile) else: raise e # get new meteo diff --git a/pyposeidon/utils/data.py b/pyposeidon/utils/data.py index d08ca1a5..ddde2eca 100644 --- a/pyposeidon/utils/data.py +++ b/pyposeidon/utils/data.py @@ -192,11 +192,15 @@ def __init__(self, **kwargs): datai.append(xdat) # append to list - merge = kwargs.get("merge", True) + if not any(datai): + logger.warning("no output netcdf files.") + self.Dataset = None + else: + merge = kwargs.get("merge", True) - if merge: - datai = flat_list(datai) - self.Dataset = xr.open_mfdataset(datai, combine="by_coords", data_vars="minimal") + if merge: + datai = flat_list(datai) + self.Dataset = xr.open_mfdataset(datai, combine="by_coords", data_vars="minimal") - else: - self.Dataset = [xr.open_mfdataset(x, combine="by_coords", data_vars="minimal") for x in datai] + else: + self.Dataset = [xr.open_mfdataset(x, combine="by_coords", data_vars="minimal") for x in datai] diff --git a/pyposeidon/utils/fix.py b/pyposeidon/utils/fix.py index 5104ade4..792f38b4 100644 --- a/pyposeidon/utils/fix.py +++ b/pyposeidon/utils/fix.py @@ -16,6 +16,7 @@ import xarray as xr import sys import os +from pyposeidon.utils.coastfix import simplify # logging setup import logging @@ -28,12 +29,16 @@ def fix(dem, coastline, **kwargs): logger.info("adjust dem\n") # --------------------------------------------------------------------- + ifunction = kwargs.get("resample_function", "nearest") + # define coastline try: shp = gp.GeoDataFrame.from_file(coastline) except: shp = gp.GeoDataFrame(coastline) + shp = simplify(shp) + if "ival" in dem.data_vars: xp = dem.ilons.values yp = dem.ilats.values @@ -125,7 +130,7 @@ def fix(dem, coastline, **kwargs): else: dem = dem.assign(adjusted=dem.elevation) - return dem + return dem, True if "ival" in dem.data_vars: df = pd.DataFrame( @@ -194,7 +199,7 @@ def fix(dem, coastline, **kwargs): xw = pw.longitude.values yw = pw.latitude.values - bw = resample(dem, xw, yw, var="elevation", wet=True, flag=flag) + bw = resample(dem, xw, yw, var="elevation", wet=True, flag=flag, function=ifunction) df.loc[pw.index, "elevation"] = bw # replace in original dataset @@ -213,7 +218,7 @@ def fix(dem, coastline, **kwargs): xl = pl.longitude.values yl = pl.latitude.values - bd = resample(dem, xl, yl, var="elevation", wet=False, flag=flag) + bd = resample(dem, xl, yl, var="elevation", wet=False, 
 
     df.loc[pl.index, "elevation"] = bd
 
     # replace in original dataset
@@ -248,36 +253,49 @@ def fix(dem, coastline, **kwargs):
 
     nanp = check1(cdem, water)
 
-    logger.info("Nan value for {} points".format(len(nanp)))
+    if len(nanp) == 0:
+        valid = True
+    else:
+        valid = False
+
+    logger.info("Nan value for {} sea points".format(len(nanp)))
 
-    on_coast = check2(cdem, shp)
+    check = kwargs.get("check", False)
 
-    logger.info("{} points on the boundary, setting to zero".format(len(on_coast)))
+    if check:
+        on_coast = check2(cdem, shp)
 
-    if "fval" in cdem.data_vars:
-        tt = len(cdem.fval.shape)
+        logger.info("{} points on the boundary, setting to zero".format(len(on_coast)))
 
-        if tt == 1:
-            cdem.fval[on_coast] = 0.0
+        if "fval" in cdem.data_vars:
+            tt = len(cdem.fval.shape)
 
-        elif tt == 2:
-            bmask = np.zeros(cdem.fval.shape, dtype=bool)  # create mask
+            if tt == 1:
+                cdem.fval[on_coast] = 0.0
+
+            elif tt == 2:
+                bmask = np.zeros(cdem.fval.shape, dtype=bool)  # create mask
+                for idx, [i, j] in enumerate(on_coast):
+                    bmask[i, j] = True
+                cdem.fval.values[bmask] = 0.0  # set value
+
+        elif "adjusted" in cdem.data_vars:
+            bmask = np.zeros(cdem.adjusted.shape, dtype=bool)  # create mask
             for idx, [i, j] in enumerate(on_coast):
                 bmask[i, j] = True
-            cdem.fval.values[bmask] = 0.0  # set value
+            cdem.adjusted.values[bmask] = 0.0  # set value
 
-    elif "adjusted" in cdem.data_vars:
-        bmask = np.zeros(cdem.adjusted.shape, dtype=bool)  # create mask
-        for idx, [i, j] in enumerate(on_coast):
-            bmask[i, j] = True
-        cdem.adjusted.values[bmask] = 0.0  # set value
+        logger.info("setting land points with nan values to zero")
+        cdem["adjusted"] = cdem.adjusted.fillna(0.0)  # for land points if any (lakes, etc.)
 
-    if len(nanp) == 0:
-        valid = True
     else:
-        valid = False
+        logger.info("setting land points with nan values to zero")
+        if "ival" in cdem.data_vars:
+            cdem["fval"] = cdem.fval.fillna(0.0)
+        elif "adjusted" in cdem.data_vars:
+            cdem["adjusted"] = cdem.adjusted.fillna(0.0)
 
-    return cdem, valid
+    return cdem, valid, flag
 
 
 def check1(dataset, water):
@@ -405,7 +423,7 @@
 
     return bps
 
 
-def resample(dem, xw, yw, var=None, wet=True, flag=None):
+def resample(dem, xw, yw, var=None, wet=True, flag=None, function="nearest"):
     # Define points with positive bathymetry
     x, y = np.meshgrid(dem.longitude, dem.latitude)
@@ -439,6 +457,80 @@
     orig = pyresample.geometry.SwathDefinition(lons=mx, lats=my)  # original bathymetry points
     targ = pyresample.geometry.SwathDefinition(lons=gx, lats=yw)  # wet points
 
-    bw = pyresample.kd_tree.resample_nearest(orig, mdem, targ, radius_of_influence=100000, fill_value=np.nan)
+    if function == "nearest":
+        bw = pyresample.kd_tree.resample_nearest(orig, mdem, targ, radius_of_influence=100000, fill_value=np.nan)
+
+    elif function == "gauss":
+        bw = pyresample.kd_tree.resample_gauss(
+            orig, mdem, targ, radius_of_influence=500000, neighbours=10, sigmas=250000, fill_value=np.nan
+        )
 
     return bw
+
+
+def dem_range(data, lon_min, lon_max, lat_min, lat_max):
+    dlon0 = round(data.longitude.data.min())
+    dlon1 = round(data.longitude.data.max())
+
+    # recenter the window
+    if dlon1 - dlon0 == 360.0:
+        lon0 = lon_min + 360.0 if lon_min < data.longitude.min() else lon_min
+        lon1 = lon_max + 360.0 if lon_max < data.longitude.min() else lon_max
+
+        lon0 = lon0 - 360.0 if lon0 > data.longitude.max() else lon0
+        lon1 = lon1 - 360.0 if lon1 > data.longitude.max() else lon1
+
+    else:
+        lon0 = lon_min
+        lon1 = lon_max
+
+    if (lon_min < data.longitude.min()) or (lon_max > data.longitude.max()):
+        logger.warning("Lon must be within {} and {}".format(data.longitude.min().values, data.longitude.max().values))
+        logger.warning("compensating if global dataset available")
+
+    if (lat_min < data.latitude.min()) or (lat_max > data.latitude.max()):
+        logger.warning("Lat must be within {} and {}".format(data.latitude.min().values, data.latitude.max().values))
+
+    # get idx
+    if lon_max - lon_min == dlon1 - dlon0:
+        i0 = 0 if lon_min == dlon0 else int(data.longitude.shape[0] / 2) + 2  # compensate for below
+        i1 = data.longitude.shape[0] if lon_max == dlon1 else -int(data.longitude.shape[0] / 2) - 2
+    else:
+        i0 = np.abs(data.longitude.data - lon0).argmin()
+        i1 = np.abs(data.longitude.data - lon1).argmin()
+
+    j0 = np.abs(data.latitude.data - lat_min).argmin()
+    j1 = np.abs(data.latitude.data - lat_max).argmin()
+
+    # expand the window a little bit
+    lon_0 = max(0, i0 - 2)
+    lon_1 = min(data.longitude.size, i1 + 2)
+
+    lat_0 = max(0, j0 - 2)
+    lat_1 = min(data.latitude.size, j1 + 2)
+
+    # descending lats
+    if j0 > j1:
+        j0, j1 = j1, j0
+        lat_0 = max(0, j0 - 1)
+        lat_1 = min(data.latitude.size, j1 + 3)
+
+    if i0 > i1:
+        p1 = data.elevation.isel(longitude=slice(lon_0, data.longitude.size), latitude=slice(lat_0, lat_1))
+
+        p1 = p1.assign_coords({"longitude": p1.longitude.values - 360.0})
+
+        p2 = data.elevation.isel(longitude=slice(0, lon_1), latitude=slice(lat_0, lat_1))
+
+        dem = xr.concat([p1, p2], dim="longitude")
+
+    else:
+        dem = data.elevation.isel(longitude=slice(lon_0, lon_1), latitude=slice(lat_0, lat_1))
+
+    if np.abs(np.mean(dem.longitude) - np.mean([lon_min, lon_max])) > 170.0:
+        c = np.sign(np.mean([lon_min, lon_max]))
+        dem["longitude"] = dem["longitude"] + c * 360.0
+
+    dem_data = xr.merge([dem])
+
+    return dem_data
diff --git a/pyposeidon/utils/obs.py b/pyposeidon/utils/obs.py
index 87a52f70..2f9e27f9 100644
--- a/pyposeidon/utils/obs.py
+++ b/pyposeidon/utils/obs.py
@@ -6,7 +6,56 @@
 from searvey import ioc
 import pandas as pd
 import geopandas as gp
+import numpy as np
+import xarray as xr
 from datetime import datetime
+import itertools
+from typing import Iterable
+from typing import Iterator
+from typing import List
+from typing import Tuple
+from typing import TypeVar
+from typing import Union
+
+_T = TypeVar("_T")
+_U = TypeVar("_U")
+
+
+def grouper(
+    iterable: Iterable[_T],
+    n: int,
+    *,
+    incomplete: str = "fill",
+    fillvalue: Union[_U, None] = None,
+) -> Iterator[Tuple[Union[_T, _U], ...]]:
+    """Collect data into non-overlapping fixed-length chunks or blocks"""
+    # grouper('ABCDEFG', 3, fillvalue='x') --> ABC DEF Gxx
+    # grouper('ABCDEFG', 3, incomplete='strict') --> ABC DEF ValueError
+    # grouper('ABCDEFG', 3, incomplete='ignore') --> ABC DEF
+    args = [iter(iterable)] * n
+    if incomplete == "fill":
+        return itertools.zip_longest(*args, fillvalue=fillvalue)
+    if incomplete == "strict":
+        return zip(*args, strict=True)  # type: ignore[call-overload]
+    if incomplete == "ignore":
+        return zip(*args)
+    else:
+        raise ValueError("Expected fill, strict, or ignore")
+
+
+def merge_datasets(datasets: List[xr.Dataset], size: int = 5) -> List[xr.Dataset]:
+    datasets = [xr.merge(g for g in group if g) for group in grouper(datasets, size)]
+    return datasets
+
+
+def get_bogus(temp, idfs, ntime):
+    return xr.Dataset(
+        {"nodata": (["ioc_code", "time"], temp)},
+        coords={
+            "ioc_code": (["ioc_code"], [idfs]),
+            "time": ntime,
+        },
+    )
 
 
 def get_obs_data(stations: str | gp.GeoDataFrame, start_time=None, end_time=None, period=None, **kwargs):
@@ -27,4 +76,38 @@
         period=period,
     )
 
+    if "id" not in stations.columns:
+        stations["id"] = [f"IOC-{x}" for x in stations.ioc_code]
+
+    s2 = stations.loc[stations.ioc_code.isin(data.ioc_code.values)].id.values
+
+    # create bogus data when no data
+    sn = stations.loc[~stations.id.isin(s2)]
+    sa = sn[["ioc_code", "lat", "lon", "country", "location"]].to_xarray()
+    ntime = data.isel(ioc_code=0).time.data
+    temp = np.array([np.NaN] * ntime.shape[0])[np.newaxis, :]  # nans for the time frame
+    xdfs = []
+    for idfs in sn.ioc_code:
+        xdfs.append(get_bogus(temp, idfs, ntime))
+
+    # in order to keep memory consumption low, let's group the datasets
+    # and merge them in batches
+    while len(xdfs) > 5:
+        xdfs = merge_datasets(xdfs)
+
+    # Do the final merging
+    sns = xr.merge(xdfs)
+
+    # set correct dims for metadata
+    sa = sa.set_coords("ioc_code").swap_dims({"index": "ioc_code"}).reset_coords("index").drop_vars("index")
+
+    # merge nodata with metadata
+    bg = xr.merge([sns, sa])
+
+    # merge bogus with searvey data
+    data = xr.merge([data, bg])
+
+    sts = stations.sort_values(by=["ioc_code"])  # make sure the indexing works
+
+    data = data.assign_coords({"id": ("ioc_code", sts.id)}).swap_dims({"ioc_code": "id"}).reset_coords("ioc_code")
+
     return data
diff --git a/pyposeidon/utils/post.py b/pyposeidon/utils/post.py
index a560be9d..8457bd0d 100644
--- a/pyposeidon/utils/post.py
+++ b/pyposeidon/utils/post.py
@@ -5,6 +5,7 @@
 import xarray as xr
 import numpy as np
 from tqdm.auto import tqdm
+import glob
 
 import pyposeidon
 from pyposeidon.utils import data
@@ -14,6 +15,30 @@
 
 # from pyposeidon.utils.detide import get_ss
 
+valid_sensors = [
+    "rad",
+    "prs",
+    "enc",
+    "pr1",
+    "PR2",
+    "pr2",
+    "ra2",
+    "bwl",
+    "wls",
+    "aqu",
+    "ras",
+    "pwl",
+    "bub",
+    "enb",
+    "atm",
+    "flt",
+    "ecs",
+    "stp",
+    "prte",
+    "prt",
+    "ra3",
+]
+
 import logging
 
 logger = logging.getLogger(__name__)
@@ -25,76 +50,104 @@ def get_encoding(ename):
     }
 
 
-def get_bogus(temp, idfs, ntime):
-    return xr.DataArray(
-        temp,
-        coords={
-            "node": (["node"], [idfs]),
-            "time": ntime,
-        },
-    )
+def save_leads(stations, st, start_date, dt, leads, rpath="./skill/"):
+    ## save results in lead chunks
+
+    for l in range(leads):
+        from_date = start_date + pd.to_timedelta("{}H".format(l * dt)) + pd.to_timedelta("1S")
+        to_date = start_date + pd.to_timedelta("{}H".format((l + 1) * dt))
+        h = st.sel(time=slice(from_date, to_date), drop=True)
+        h = h.rename({"elev": "elev_fct", "time": "ftime"})
+        leadfile = os.path.join(rpath, f"lead{l}")
+        if not os.path.exists(leadfile):
+            h.to_zarr(store=leadfile, mode="a")
+        else:
+            h.to_zarr(store=leadfile, mode="a", append_dim="ftime")
+
+    return
 
 
-def compute_obs(stations, start_time, end_time):
-    logger.info("Retrieve observation data for station points\n")
+def gather_obs_data(stations, start_time, end_time, rpath="./thalassa/obs/"):
+    logger.info(f"Retrieve observation data for station points from {start_time} to {end_time}\n")
 
     odata = get_obs_data(stations=stations, start_time=start_time, end_time=end_time)
 
-    dfs = [x for x in stations.ioc_code if x not in odata.ioc_code]
-    idx = stations[stations["ioc_code"].isin(dfs)].index
+    logger.info("Save observation data for station locations\n")
+
+    for inode in tqdm(odata.id.values):
+        oi = odata.sel(id=inode)
+
+        var = [k for k, v in oi.data_vars.items() if "time" in v.dims]
+
+        df = oi[var].to_dataframe().drop("id", axis=1)
+        df_ = df.dropna(axis=1, how="all")  # drop all nan columns
+
+        file_path = os.path.join(rpath, f"{inode}.parquet")
 
-    logger.info("Normalize observation data for station points\n")
+        if os.path.isfile(file_path):
+            obs = pd.read_parquet(file_path, engine="fastparquet")
+            # make sure there is output
+            if df_.empty:
+                cols = [x for x in var if x in obs.columns]
+                df = df[cols]
+                df.to_parquet(file_path, engine="fastparquet", append=True)
+            else:
+                df = df_
+                out = pd.concat([obs, df]).dropna(how="all")  # merge
+                out.to_parquet(file_path, engine="fastparquet")
+        else:
+            # make sure there is output
+            if df_.empty:
+                df = df[[df.columns[0]]]
+            else:
+                df = df_
 
-    od = []
-    for inode in tqdm(odata.ioc_code.values):
-        oi = odata.sel(ioc_code=inode)
+            df.to_parquet(file_path, engine="fastparquet")
 
-        for var in oi.data_vars:
-            if oi[var].isnull().all().values == True:
-                oi = oi.drop(var)
+    return
 
-        var = [k for k, v in oi.data_vars.items() if v.dims == ("time",)][0]
-        obs = oi[var].to_dataframe().drop(["ioc_code"], axis=1)  # Get observational data
 
-        # de-tide obs
-        # if not obs[var].dropna().empty | (obs[var].dropna() == obs[var].dropna()[0]).all():
-        #     obs = get_ss(obs, oi.lat.values)
-        #     oi[var].values = obs.elev.values
+def to_stats(st, rpath="./thalassa/", opath="./thalassa/obs/"):
+    logger.info("compute general statistics for station points\n")
 
-        od.append(oi[var].rename("elev_obs"))
+    ids = st.id.values
 
-    ods = xr.concat(od, dim="ioc_code").rename({"ioc_code": "node"})  # obs
+    sts = []
+    for inode in tqdm(ids):
+        sim = st.sel(id=inode).elev_sim.to_dataframe().drop("id", axis=1)
 
-    ## add bogus observations in order to keep array shape
+        filename = os.path.join(opath, f"{inode}.parquet")
+        obs = pd.read_parquet(filename, engine="fastparquet")
+        obs_ = obs.dropna(axis=1, how="all")  # drop all nan columns
 
-    ntime = odata.isel(ioc_code=0).time.data
-    temp = np.array([np.NaN] * ntime.shape[0])[np.newaxis, :]
-    xdfs = []
-    for idfs in dfs:
-        xdfs.append(get_bogus(temp, idfs, ntime))
+        if obs_.empty:
+            obs = obs[[obs.columns[0]]]
+        else:
+            cols = [x for x in obs_.columns if x in valid_sensors]
+            obs = obs_[[cols[0]]]  # just choose one for now
 
-    ods = xr.concat([ods] + xdfs, dim="node")
+        if obs.dropna().empty:
+            logger.warning(f"Observation data not available for {inode} station")
 
-    return ods
+        stable = get_stats(sim, obs)  # Do general statistics
 
+        sts.append(stable)
 
-def save_ods(ods, **kwargs):
-    rpath = kwargs.get("rpath", "./thalassa/")
+    logger.info("save stats\n")
 
-    # obs data
-    logger.info("saving observations data\n")
-    obs_file = os.path.join(rpath, f"searvey")
-    if not os.path.exists(obs_file):
-        ods.to_zarr(store=obs_file, mode="a")
-    else:
-        ods.to_zarr(store=obs_file, mode="a", append_dim="time")
+    stats = pd.DataFrame(sts)
+    stats.index.name = "id"
+    stats = stats.to_xarray().assign_coords({"id": ids})  # stats
 
-    return
+    output_path = os.path.join(rpath, "stats.nc")
 
+    stats.to_netcdf(output_path)
+    logger.info("..done with stats file\n")
 
-def to_thalassa(folder, freq=None, **kwargs):
+
+def to_thalassa(folder, **kwargs):
     # Retrieve data
     tag = kwargs.get("tag", "schism")
+    leads = kwargs.get("leads", None)
     rpath = kwargs.get("rpath", "./thalassa/")
     gglobal = kwargs.get("gglobal", False)
     to2d = kwargs.get("to2d", None)
@@ -113,7 +166,8 @@
         os.makedirs(rpath)
 
     # Get simulation data
-    b = pyposeidon.model.read(folder + "/{}_model.json".format(tag))
+    json_file = os.path.join(folder, "{}_model.json".format(tag))
+    b = pyposeidon.model.read(json_file)
 
     b.get_output_data()
     b.get_station_sim_data()
     st = b.station_sim_data
@@ -126,22 +180,22 @@
         logger.info("converting to 2D\n")
         # Convert to 2D
         [xn, yn, tri3n] = np.load(to2d, allow_pickle=True)
-        sv = []
-        for var in rvars:
-            isv = to_2d(out, var=var, mesh=[xn, yn, tri3n])  # elevation
-            sv.append(isv)
-        out = xr.merge(sv)
+        out = to_2d(out, data_vars=rvars, mesh=[xn, yn, tri3n])  # elevation
     else:
         rvars_ = rvars + [x_var, y_var, tes_var]
         out = out[rvars_]
 
+    # Add max elevation variable
+    out = out.assign(max_elev=out[ename].max("time"))
+    rvars = rvars + ["max_elev"]
+
     # set enconding
     encoding = {}
     for var in rvars:
         encoding.update(get_encoding(var))
 
-    logger.info("Saving combined netcdf file for folder {}\n".format(folder))
+    logger.info("saving combined netcdf file for folder {}\n".format(folder))
     output_file = os.path.join(
         rpath,
         f"{b.start_date.strftime('%Y%m%d%H')}.nc",
     )
@@ -150,96 +204,71 @@
 
     stations = gp.GeoDataFrame.from_file(b.obs)
 
-    ## save results in lead chunks
-    lpath = os.path.join(rpath, "skill")
+    # assign unique id
+    if "id" not in stations.columns:
+        stations["id"] = [f"IOC-{x}" for x in stations.ioc_code]
+    st = st.assign_coords({"id": ("node", stations.id.values)}).swap_dims({"node": "id"}).reset_coords("node")
+
+    logger.info("save time series depending on lead time\n")
+
+    if leads:
+        skill_path = os.path.join(rpath, "skill")
+        total_hours = pd.to_timedelta(b.time_frame) / pd.Timedelta(hours=1)
+        dt = total_hours / leads
+
+        if dt % 1 == 0.0:
+            save_leads(stations, st, b.start_date, dt, leads, rpath=skill_path)
 
-    for l in range(freq):
-        from_date = b.start_date + pd.to_timedelta("{}H".format(l * 12)) + pd.to_timedelta("1S")
-        to_date = b.start_date + pd.to_timedelta("{}H".format((l + 1) * 12))
-        h = st.sel(time=slice(from_date, to_date), drop=True)
-        h = h.rename({"elev": "elev_fct", "time": "ftime"})
-        h = h.assign_coords({"node": stations.ioc_code.to_list()})
-        leadfile = os.path.join(lpath, f"lead{l}")
-        if not os.path.exists(leadfile):
-            h.to_zarr(store=leadfile, mode="a")
         else:
-            h.to_zarr(store=leadfile, mode="a", append_dim="ftime")
+            logger.warning("leads does not split the time frame into whole hours, skipping lead chunks\n")
 
-    locations = stations.ioc_code.values
+    # save sim data
 
-    # save to files
-    logger.info("Construct station simulation data output\n")
-    # Construct Thalassa file
+    ids = stations.id.values
     stp = stations.to_xarray().rename({"index": "node"})  # stations
-    stp = stp.assign_coords({"node": locations})
+    stp = stp.assign_coords({"id": ("node", ids)}).swap_dims({"node": "id"}).reset_coords("node")
 
-    st_ = st.assign_coords(node=locations)
-    st_ = st_.rename({"elev": "elev_sim", "time": "stime"})
+    st_ = st.rename({"elev": "elev_sim", "time": "stime"})
 
     vdata = xr.merge([stp, st_])
 
     vdata = vdata.drop_vars("geometry")
 
-    logger.info("Save station simulation data output\n")
+    logger.info("save station simulation data output\n")
 
     output_path = os.path.join(rpath, filename)
 
     vdata.to_netcdf(output_path)
     logger.info(f"..done with {filename} file\n")
 
-    # get observations
-    ods = compute_obs(stations, b.start_date, b.end_date)
-
-    save_ods(ods, **kwargs)
-
-    # compute stats
-    st = st.assign({"ioc_code": ("node", stations["ioc_code"])})
-    sts = compute_stats(st, ods)
-    save_stats(sts, stations, **kwargs)
-
-    logger.info(f"post processing complete for folder {folder}\n")
-
-
-def compute_stats(st, ods):
-    logger.info("Compute general statistics for station points\n")
-
-    sts = []
-    for inode in tqdm(st.ioc_code.values):
-        isim = st.where(st.ioc_code == inode).dropna(dim="node")
-        sim = isim.elev.to_dataframe().droplevel(1)
-
-        obs = ods.sel(node=inode).to_dataframe().drop("node", axis=1)
-
-        stable = get_stats(sim, obs)  # Do general statitics
-
-        sts.append(stable)
-
-    return sts
-
-
-def save_stats(sts, stations, **kwargs):
+def to_obs(folder, **kwargs):
+    tag = kwargs.get("tag", "schism")
     rpath = kwargs.get("rpath", "./thalassa/")
 
-    logger.info("Save stats\n")
+    json_file = os.path.join(folder, "{}_model.json".format(tag))
+    b = pyposeidon.model.read(json_file)
 
-    locations = stations.ioc_code.values
-
-    stats = pd.DataFrame(sts)
-    stats.index.name = "node"
-    stats_ = stats.to_xarray().assign_coords({"node": locations})  # stats
+    stations = gp.GeoDataFrame.from_file(b.obs)
 
-    stp = stations.to_xarray().rename({"index": "node"})  # stations
-    stp = stp.assign_coords({"node": locations})
+    # assign unique id
+    if "id" not in stations.columns:
+        stations["id"] = [f"IOC-{x}" for x in stations.ioc_code]
 
-    sdata = xr.merge([stp, stats_])
+    # get observations last timestamp
+    obs_files_path = os.path.join(rpath, "obs/")
+    if not os.path.exists(obs_files_path):
+        os.makedirs(obs_files_path)
 
-    sdata = sdata.drop_vars("geometry")
+    obs_files = glob.glob(obs_files_path + "*.parquet")
 
-    rpath = kwargs.get("rpath", "./thalassa/")
+    if not obs_files:
+        start_date = b.start_date
+    else:
+        obs_ = pd.read_parquet(obs_files[0], engine="fastparquet")
+        start_date = obs_.index[-1] + pd.Timedelta("60S")
 
-    output_path = os.path.join(rpath, "stats.nc")
+    gather_obs_data(stations, start_date, b.end_date, rpath=obs_files_path)
 
-    sdata.to_netcdf(output_path)
-    logger.info(f"..done with stats file\n")
+    return
diff --git a/pyposeidon/utils/pplot.py b/pyposeidon/utils/pplot.py
index 8aa66f6f..7b7a8ad8 100644
--- a/pyposeidon/utils/pplot.py
+++ b/pyposeidon/utils/pplot.py
@@ -219,14 +219,14 @@ def contour(self, ax=None, it=None, **kwargs):
         tes_var = kwargs.get("e", "SCHISM_hgrid_face_nodes")
         t_var = kwargs.get("t", "time")
 
-        x = self._obj[x_var][:].values
-        y = self._obj[y_var][:].values
+        x = self._obj[x_var][:]
+        y = self._obj[y_var][:]
         try:
-            t = self._obj[t_var][:].values
+            t = self._obj[t_var].data
         except:
             pass
 
-        tes = self._obj[tes_var].values[:, :4]
+        tes = self._obj[tes_var]
 
         # check tesselation
@@ -332,14 +332,14 @@ def contourf(self, ax=None, it=None, **kwargs):
         tes_var = kwargs.get("e", "SCHISM_hgrid_face_nodes")
         t_var = kwargs.get("t", "time")
 
-        x = self._obj[x_var][:].values
-        y = self._obj[y_var][:].values
+        x = self._obj[x_var]
+        y = self._obj[y_var]
         try:
-            t = self._obj[t_var][:].values
+            t = self._obj[t_var].data
         except:
             pass
 
-        tes = self._obj[tes_var].values[:, :4]
+        tes = self._obj[tes_var]
 
         # check tesselation
@@ -372,7 +372,18 @@ def contourf(self, ax=None, it=None, **kwargs):
             pass
 
         var = kwargs.get("var", "depth")
-        z = kwargs.get("z", self._obj[var].values[it, :].flatten())
+
+        if len(self._obj[var].shape) == 1:
+            zv = self._obj[var].data.flatten()
+        elif len(self._obj[var].shape) == 2:
+            if t_var in self._obj[var].coords:
+                zv = self._obj[var][it, :].data.flatten()
+            else:
+                raise Exception(f"{t_var} not in {var} dims")
+        else:
+            raise Exception(f"{var} dimension is larger than 2, please subset")
+
+        z = kwargs.get("z", zv)
 
         nv = kwargs.get("nv", 10)
@@ -406,7 +417,7 @@ def contourf(self, ax=None, it=None, **kwargs):
 
         vrange = np.linspace(vmin, vmax, nv, endpoint=True)
 
-        xy = kwargs.get("xy", (0.05, -0.1))
+        xy = kwargs.get("xy", (-0.1, -0.3))
 
         for val in [
             "x",
@@ -450,10 +461,10 @@ def quiver(self, ax=None, it=None, u=None, v=None, title=None, scale=0.1, color=
         y_var = kwargs.get("y", "SCHISM_hgrid_node_y")
         t_var = kwargs.get("t", "time")
 
-        x = self._obj[x_var][:].values
-        y = self._obj[y_var][:].values
+        x = self._obj[x_var][:]
+        y = self._obj[y_var][:]
         try:
-            t = self._obj[t_var][:].values
+            t = self._obj[t_var].data
         except:
             pass
@@ -516,9 +527,9 @@ def mesh(self, ax=None, lw: float = 0.5, markersize: float = 1.0, **kwargs):
         y_var = kwargs.get("y", "SCHISM_hgrid_node_y")
         tes_var = kwargs.get("e", "SCHISM_hgrid_face_nodes")
 
-        x = self._obj[x_var][:].values
-        y = self._obj[y_var][:].values
-        tes = self._obj[tes_var].values[:, :4]
+        x = self._obj[x_var][:]
+        y = self._obj[y_var][:]
+        tes = self._obj[tes_var]
 
         # check tesselation
@@ -591,9 +602,9 @@ def qframes(self, ax=None, u=None, v=None, scale=0.01, color="k", **kwargs):
         y_var = kwargs.get("y", "SCHISM_hgrid_node_y")
         t_var = kwargs.get("t", "time")
 
-        x = self._obj[x_var][:].values
-        y = self._obj[y_var][:].values
-        t = self._obj[t_var][:].values
+        x = self._obj[x_var]
+        y = self._obj[y_var]
+        t = self._obj[t_var]
 
         cr = kwargs.get("coastlines", None)
         c_attrs = kwargs.get("coastlines_attrs", {})
@@ -659,10 +670,10 @@ def frames(self, **kwargs):
         tes_var = kwargs.get("e", "SCHISM_hgrid_face_nodes")
         t_var = kwargs.get("t", "time")
 
-        x = self._obj[x_var][:].values
-        y = self._obj[y_var][:].values
-        t = self._obj[t_var][:].values
-        tes = self._obj[tes_var].values[:, :4]
+        x = self._obj[x_var]
+        y = self._obj[y_var]
+        t = self._obj[t_var].data
+        tes = self._obj[tes_var]
 
         # check tesselation
@@ -688,7 +699,7 @@ def frames(self, **kwargs):
         tri3 = tri3 - 1  # fortran/python conversion
 
         var = kwargs.get("var", "depth")
-        z = kwargs.get("z", self._obj[var].values)
+        z = kwargs.get("z", self._obj[var])
 
         # set figure size
         xr = x.max() - x.min()
diff --git a/pyposeidon/utils/seam.py b/pyposeidon/utils/seam.py
index 726fbdaf..42c3da88 100644
--- a/pyposeidon/utils/seam.py
+++ b/pyposeidon/utils/seam.py
@@ -146,7 +146,9 @@ def get_seam(x, y, z, tri3, **kwargs):
     return xx, yy, ges.values
 
 
-def to_2d(dataset=None, var=None, mesh=None, **kwargs):
+def to_2d(
+    dataset: xr.Dataset = None, data_vars: list[str] = None, mesh: list[np.ndarray] = None, **kwargs
+) -> xr.Dataset:
     x_var = kwargs.get("x", "SCHISM_hgrid_node_x")
     y_var = kwargs.get("y", "SCHISM_hgrid_node_y")
     tes_var = kwargs.get("e", "SCHISM_hgrid_face_nodes")
@@ -182,44 +184,19 @@
     orig = pyresample.geometry.SwathDefinition(lons=px - d, lats=py)
     targ = pyresample.geometry.SwathDefinition(lons=pxi - d, lats=pyi)
 
-    # Resample
-    if len(dataset[var].shape) == 1:
-        z = dataset[var].values
-        zm = z[xmask]
-        z_ = pyresample.kd_tree.resample_nearest(orig, zm, targ, radius_of_influence=200000)  # , fill_value=0)
-        xelev = np.concatenate((z, z_))
-
-        # create xarray
-        xe = xr.Dataset(
-            {
-                var: (["nSCHISM_hgrid_node"], xelev),
-                "SCHISM_hgrid_node_x": (["nSCHISM_hgrid_node"], xn),
-                "SCHISM_hgrid_node_y": (["nSCHISM_hgrid_node"], yn),
-                "SCHISM_hgrid_face_nodes": (
-                    ["nSCHISM_hgrid_face", "nMaxSCHISM_hgrid_face_nodes"],
-                    tri3n,
-                ),
-            },
-        )
-
-    elif "time" in dataset[var].coords:
-        it_start = kwargs.get("it_start", 0)
-        it_end = kwargs.get("it_end", dataset.time.shape[0])
-
-        if not os.path.exists("./seamtmp/"):
-            os.makedirs("./seamtmp/")
-
-        for i in tqdm(range(it_start, it_end)):
-            z = dataset[var].values[i, :]
+    xes = []
+    for var in data_vars:
+        # Resample
+        if len(dataset[var].shape) == 1:
+            z = dataset[var].values
             zm = z[xmask]
-            z_ = pyresample.kd_tree.resample_nearest(orig, zm, targ, radius_of_influence=200000, fill_value=0)
-            e = np.concatenate((z, z_))
-            e = e[np.newaxis, :]  # make 2d
+            z_ = pyresample.kd_tree.resample_nearest(orig, zm, targ, radius_of_influence=200000)  # , fill_value=0)
+            xelev = np.concatenate((z, z_))
 
             # create xarray
-            xi = xr.Dataset(
+            xe = xr.Dataset(
                 {
-                    var: (["time", "nSCHISM_hgrid_node"], e),
+                    var: (["nSCHISM_hgrid_node"], xelev),
                     "SCHISM_hgrid_node_x": (["nSCHISM_hgrid_node"], xn),
                     "SCHISM_hgrid_node_y": (["nSCHISM_hgrid_node"], yn),
                     "SCHISM_hgrid_face_nodes": (
@@ -227,20 +204,54 @@
                         ["nSCHISM_hgrid_face", "nMaxSCHISM_hgrid_face_nodes"],
                         tri3n,
                     ),
                 },
-                coords={"time": ("time", [dataset.time.values[i]])},
             )
 
-            xi.to_netcdf("./seamtmp/x_{:03d}.nc".format(i))
+            xe.attrs.update(dataset.attrs)
 
-        xe = xr.open_mfdataset("./seamtmp/x_*.nc", data_vars="minimal")
+        elif "time" in dataset[var].coords:
+            it_start = kwargs.get("it_start", 0)
+            it_end = kwargs.get("it_end", dataset.time.shape[0])
 
-        # cleanup
-        xfiles = glob("./seamtmp/x_*.nc")
-        for f in xfiles:
-            os.remove(f)
-        os.removedirs("./seamtmp/")
+            if not os.path.exists("./seamtmp/"):
+                os.makedirs("./seamtmp/")
 
-        return xe
+            for i in tqdm(range(it_start, it_end)):
+                z = dataset[var].values[i, :]
+                zm = z[xmask]
+                z_ = pyresample.kd_tree.resample_nearest(orig, zm, targ, radius_of_influence=200000, fill_value=0)
+                e = np.concatenate((z, z_))
+                e = e[np.newaxis, :]  # make 2d
+
+                # create xarray
+                xi = xr.Dataset(
+                    {
+                        var: (["time", "nSCHISM_hgrid_node"], e),
+                        "SCHISM_hgrid_node_x": (["nSCHISM_hgrid_node"], xn),
+                        "SCHISM_hgrid_node_y": (["nSCHISM_hgrid_node"], yn),
+                        "SCHISM_hgrid_face_nodes": (
+                            ["nSCHISM_hgrid_face", "nMaxSCHISM_hgrid_face_nodes"],
+                            tri3n,
+                        ),
+                    },
+                    coords={"time": ("time", [dataset.time.values[i]])},
+                )
+
+                xi.to_netcdf("./seamtmp/x_{:03d}.nc".format(i))
+
+            xe = xr.open_mfdataset("./seamtmp/x_*.nc", data_vars="minimal")
+            xe.attrs.update(dataset.attrs)
+
+            # cleanup
+            xfiles = glob("./seamtmp/x_*.nc")
+            for f in xfiles:
+                os.remove(f)
+            os.removedirs("./seamtmp/")
+
+        xes.append(xe)
+
+    xf = xr.merge(xes)
+
+    return xf
 
 
 def reposition(px):
diff --git a/pyposeidon/utils/statistics.py b/pyposeidon/utils/statistics.py
index a0c9d328..8d4a0146 100755
--- a/pyposeidon/utils/statistics.py
+++ b/pyposeidon/utils/statistics.py
@@ -9,33 +9,34 @@ def get_stats(sim_, obs_):
-    # match time frames
-    try:
+    if obs_.dropna().empty:
+        stats = pd.Series(
+            index={
+                "Mean Absolute Error",
+                "RMSE",
+                "Scatter Index",
+                "percentage RMSE",
+                "BIAS or mean error",
+                "Standard deviation of residuals",
+                "Correlation Coefficient",
+                "R^2",
+                "Nash-Sutcliffe Coefficient",
+                "lamda index",
+            },
+            dtype="float64",
+        )
+        return stats
+
+    else:
+        # match time frames
         start = max(obs_.index[0], sim_.index[0])
         end = min(obs_.index[-1], sim_.index[-1])
-        obs_ = obs_.loc[start:end]
-        sim_ = sim_.loc[start:end]
+
+        dates = pd.date_range(start, end, freq="s")
+        obs_ = obs_[obs_.index.isin(dates)]
+        sim_ = sim_[sim_.index.isin(dates)]
 
         obs_ = obs_.reindex(sim_.index, method="nearest")  # sample on simulation times
 
         return vtable(obs_.values, sim_.values)
-    except:
-        if obs_.dropna().empty:
-            logger.warning("Observation data not available for this station")
-            stats = pd.Series(
-                index={
-                    "Mean Absolute Error",
-                    "RMSE",
-                    "Scatter Index",
-                    "percentage RMSE",
-                    "BIAS or mean error",
-                    "Standard deviation of residuals",
-                    "Correlation Coefficient",
-                    "R^2",
-                    "Nash-Sutcliffe Coefficient",
-                    "lamda index",
-                },
-                dtype="float64",
-            )
-        return stats
 
 
 def vtable(obsrv, model):
diff --git a/tests/test_schism.py b/tests/test_schism.py
index c55ebd9f..20d154a1 100644
--- a/tests/test_schism.py
+++ b/tests/test_schism.py
@@ -84,6 +84,32 @@
 }
 
 
+case4 = {
+    "solver_name": "schism",
+    "mesh_file": MESH_FILE,
+    "manning": 0.12,
+    "windrot": 0.00001,
+    "tag": "test",
+    "start_date": "2011-1-1 0:0:0",
+    "time_frame": "12H",
+    "meteo_source": [(DATA_DIR / "era5.grib").as_posix()],  # meteo file
+    "dem_source": DEM_FILE,
+    "monitor": True,
+    "update": ["all"],  # update only meteo, keep dem
+    "parameters": {
+        "dt": 400,
+        "rnday": 0.3,
+        "nhot": 0,
+        "ihot": 0,
+        "nspool": 9,
+        "ihfskip": 36,
+        "nhot_write": 108,
+        "nc_out": 0,
+    },
+    "scribes": 2,
+}
+
+
 def schism(tmpdir, dic):
     # initialize a model
     rpath = str(tmpdir) + "/"
@@ -105,7 +131,7 @@ def schism(tmpdir, dic):
 
 
 @pytest.mark.schism
-@pytest.mark.parametrize("case", [case1, case2, case3])
+@pytest.mark.parametrize("case", [case1, case2, case3, case4])
 def test_answer(tmpdir, case):
     assert schism(tmpdir, case) == True
diff --git a/tests/test_schism_cast.py b/tests/test_schism_cast.py
index 5095811a..9dbede08 100644
--- a/tests/test_schism_cast.py
+++ b/tests/test_schism_cast.py
@@ -135,7 +135,8 @@ def test_schism_cast(tmpdir, copy):
         symlinked_file = os.path.join(next_rpath, filename)
         assert os.path.exists(symlinked_file)
         assert os.path.islink(symlinked_file)
-        assert os.path.realpath(symlinked_file) == original_file
+        if not os.path.islink(original_file):
+            assert os.path.realpath(symlinked_file) == original_file
 
 
 @pytest.mark.schism