From a1e88dfeb1bb84d7b751ec46cbe8a948d53101e2 Mon Sep 17 00:00:00 2001 From: George Breyiannis Date: Sun, 1 Oct 2023 15:33:46 +0200 Subject: [PATCH] seam: support list of variables - resolve #164 moving to list of variables as argument for one call functionality --- pyposeidon/utils/post.py | 6 +-- pyposeidon/utils/seam.py | 100 +++++++++++++++++++++------------------ 2 files changed, 54 insertions(+), 52 deletions(-) diff --git a/pyposeidon/utils/post.py b/pyposeidon/utils/post.py index 22c1b825..8457bd0d 100644 --- a/pyposeidon/utils/post.py +++ b/pyposeidon/utils/post.py @@ -180,11 +180,7 @@ def to_thalassa(folder, **kwargs): logger.info("converting to 2D\n") # Convert to 2D [xn, yn, tri3n] = np.load(to2d, allow_pickle=True) - sv = [] - for var in rvars: - isv = to_2d(out, var=var, mesh=[xn, yn, tri3n]) # elevation - sv.append(isv) - out = xr.merge(sv) + out = to_2d(out, data_vars=rvars, mesh=[xn, yn, tri3n]) # elevation else: rvars_ = rvars + [x_var, y_var, tes_var] diff --git a/pyposeidon/utils/seam.py b/pyposeidon/utils/seam.py index 6eb925ce..142d8d00 100644 --- a/pyposeidon/utils/seam.py +++ b/pyposeidon/utils/seam.py @@ -146,7 +146,7 @@ def get_seam(x, y, z, tri3, **kwargs): return xx, yy, ges.values -def to_2d(dataset=None, var=None, mesh=None, **kwargs): +def to_2d(dataset: xr.Dataset = None, data_vars: list[str] = None, mesh: list[np.array] = None, **kwargs) -> xr.Dataset: x_var = kwargs.get("x", "SCHISM_hgrid_node_x") y_var = kwargs.get("y", "SCHISM_hgrid_node_y") tes_var = kwargs.get("e", "SCHISM_hgrid_face_nodes") @@ -182,46 +182,19 @@ def to_2d(dataset=None, var=None, mesh=None, **kwargs): orig = pyresample.geometry.SwathDefinition(lons=px - d, lats=py) targ = pyresample.geometry.SwathDefinition(lons=pxi - d, lats=pyi) - # Resample - if len(dataset[var].shape) == 1: - z = dataset[var].values - zm = z[xmask] - z_ = pyresample.kd_tree.resample_nearest(orig, zm, targ, radius_of_influence=200000) # , fill_value=0) - xelev = np.concatenate((z, z_)) - - # create xarray - xe = xr.Dataset( - { - var: (["nSCHISM_hgrid_node"], xelev), - "SCHISM_hgrid_node_x": (["nSCHISM_hgrid_node"], xn), - "SCHISM_hgrid_node_y": (["nSCHISM_hgrid_node"], yn), - "SCHISM_hgrid_face_nodes": ( - ["nSCHISM_hgrid_face", "nMaxSCHISM_hgrid_face_nodes"], - tri3n, - ), - }, - ) - - xe.attrs.update(dataset.attrs) - - elif "time" in dataset[var].coords: - it_start = kwargs.get("it_start", 0) - it_end = kwargs.get("it_end", dataset.time.shape[0]) - - if not os.path.exists("./seamtmp/"): - os.makedirs("./seamtmp/") - - for i in tqdm(range(it_start, it_end)): - z = dataset[var].values[i, :] + xes = [] + for var in data_vars: + # Resample + if len(dataset[var].shape) == 1: + z = dataset[var].values zm = z[xmask] - z_ = pyresample.kd_tree.resample_nearest(orig, zm, targ, radius_of_influence=200000, fill_value=0) - e = np.concatenate((z, z_)) - e = e[np.newaxis, :] # make 2d + z_ = pyresample.kd_tree.resample_nearest(orig, zm, targ, radius_of_influence=200000) # , fill_value=0) + xelev = np.concatenate((z, z_)) # create xarray - xi = xr.Dataset( + xe = xr.Dataset( { - var: (["time", "nSCHISM_hgrid_node"], e), + var: (["nSCHISM_hgrid_node"], xelev), "SCHISM_hgrid_node_x": (["nSCHISM_hgrid_node"], xn), "SCHISM_hgrid_node_y": (["nSCHISM_hgrid_node"], yn), "SCHISM_hgrid_face_nodes": ( @@ -229,21 +202,54 @@ def to_2d(dataset=None, var=None, mesh=None, **kwargs): tri3n, ), }, - coords={"time": ("time", [dataset.time.values[i]])}, ) - xi.to_netcdf("./seamtmp/x_{:03d}.nc".format(i)) + xe.attrs.update(dataset.attrs) - xe = xr.open_mfdataset("./seamtmp/x_*.nc", data_vars="minimal") - xe.attrs.update(dataset.attrs) + elif "time" in dataset[var].coords: + it_start = kwargs.get("it_start", 0) + it_end = kwargs.get("it_end", dataset.time.shape[0]) - # cleanup - xfiles = glob("./seamtmp/x_*.nc") - for f in xfiles: - os.remove(f) - os.removedirs("./seamtmp/") + if not os.path.exists("./seamtmp/"): + os.makedirs("./seamtmp/") - return xe + for i in tqdm(range(it_start, it_end)): + z = dataset[var].values[i, :] + zm = z[xmask] + z_ = pyresample.kd_tree.resample_nearest(orig, zm, targ, radius_of_influence=200000, fill_value=0) + e = np.concatenate((z, z_)) + e = e[np.newaxis, :] # make 2d + + # create xarray + xi = xr.Dataset( + { + var: (["time", "nSCHISM_hgrid_node"], e), + "SCHISM_hgrid_node_x": (["nSCHISM_hgrid_node"], xn), + "SCHISM_hgrid_node_y": (["nSCHISM_hgrid_node"], yn), + "SCHISM_hgrid_face_nodes": ( + ["nSCHISM_hgrid_face", "nMaxSCHISM_hgrid_face_nodes"], + tri3n, + ), + }, + coords={"time": ("time", [dataset.time.values[i]])}, + ) + + xi.to_netcdf("./seamtmp/x_{:03d}.nc".format(i)) + + xe = xr.open_mfdataset("./seamtmp/x_*.nc", data_vars="minimal") + xe.attrs.update(dataset.attrs) + + # cleanup + xfiles = glob("./seamtmp/x_*.nc") + for f in xfiles: + os.remove(f) + os.removedirs("./seamtmp/") + + xes.append(xe) + + xf = xr.merge(xes) + + return xf def reposition(px):