Skip to content

Commit

Permalink
added mesh_check() function
Browse files Browse the repository at this point in the history
  • Loading branch information
tomsail committed Dec 21, 2023
1 parent a3572f0 commit ebb750b
Show file tree
Hide file tree
Showing 2 changed files with 144 additions and 67 deletions.
4 changes: 2 additions & 2 deletions pyposeidon/boundary.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def __init__(self, **kwargs):
cbuffer = kwargs.get("cbuffer", None)
blevels = kwargs.get("blevels", None)
prad = kwargs.get("R", 1.0)
rpath = kwargs.get("rpath", '.')
rpath = kwargs.get("rpath", ".")

# COASTLINES
if coastlines is None:
Expand Down Expand Up @@ -143,7 +143,7 @@ def __init__(self, **kwargs):
os.makedirs(gpath)

self.coasts.set_crs(epsg=4326, inplace=True)
self.coasts.to_file(os.path.join(gpath, "coasts.shp"), driver = 'ESRI Shapefile')
self.coasts.to_file(os.path.join(gpath, "coasts.shp"), driver="ESRI Shapefile")
df = global_tag(self.coasts, cbuffer, blevels, R=prad)
elif isinstance(self.geometry, gp.GeoDataFrame):
df = self.geometry
Expand Down
207 changes: 142 additions & 65 deletions pyposeidon/moceanmesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,76 @@ def get_ibounds(df, mm):
return ibounds


def check_mesh(ds, stereo_to_ll=False):
"""
Check the mesh and reverse any counter-clockwise (CW) triangles.
If the determinant of the cross product of two edges of a triangle is negative,
the triangle is CW. This function reverses the triangle if it is CW.
Additionally, this function removes any flat (degenerate) triangles from the mesh.
A triangle is considered flat if its area is less than a threshold value.
Parameters
----------
ds : xr.Dataset
Returns
-------
ds : xr.Dataset
TODO: MOVE THIS FUNCTION IN A COMMON SPACE (FOR OTHER MESHERS)
"""
logger.info("checking mesh..\n")

print(ds)
tris = ds.SCHISM_hgrid_face_nodes.data
x = ds.SCHISM_hgrid_node_x.data
y = ds.SCHISM_hgrid_node_y.data

t12 = -x[tris[:, 0]] + x[tris[:, 1]]
t13 = -x[tris[:, 0]] + x[tris[:, 2]]
t22 = -y[tris[:, 0]] + y[tris[:, 1]]
t23 = -y[tris[:, 0]] + y[tris[:, 2]]
#
det = t12 * t23 - t22 * t13 # as defined in GEOELT (sources/utils/geoelt.f)
#
ccw_ = det > 0
cw_ = det < 0
flat_ = abs(det) <= 10e-6

# Reverse CW triangles
ds.SCHISM_hgrid_face_nodes[~ccw_] = ds.SCHISM_hgrid_face_nodes[~ccw_][:, ::-1]
if cw_.sum() > 0:
logger.info(" > reversed " + str(cw_.sum()) + " CW triangles")

# CREATE NEW DATASET TO REMOVE FLAT TRIANGLES
if flat_.sum() > 0:
non_flat_tris = tris[~flat_]
nodes_ = np.unique(non_flat_tris)

# Create a mapping from old indices to new indices
map_old_new_ = {old_idx: new_idx for new_idx, old_idx in enumerate(nodes_)}

# Apply the mapping to update the indices in non_flat_tris
non_flat_tris_mapped = np.vectorize(map_old_new_.get)(non_flat_tris)

x_ = ds.SCHISM_hgrid_node_x.values[nodes_]
y_ = ds.SCHISM_hgrid_node_y.values[nodes_]

# using oceanmesh to cleanup and fix the mesh
points, cells = om.make_mesh_boundaries_traversable(np.column_stack((x_, y_)), non_flat_tris_mapped)
points, cells = om.delete_faces_connected_to_one_face(points, cells)

ds = om_to_xarray(points, cells, stereo_to_ll=stereo_to_ll)
logger.info(" > removed " + str(flat_.sum()) + " flat triangles")
logger.info(f" > Filtered {len(ds.node.values) - len(ds.node.values)} boundary nodes")
logger.info(f" > removed {len(tris) - len(cells)} elements in total.\n")

return ds


def get(contours, **kwargs):
"""
Create a `oceanmesh` mesh.
Expand Down Expand Up @@ -485,7 +555,6 @@ def make_oceanmesh(df, **kwargs):
tria = tria.apply(pd.to_numeric)

# boundaries

logger.info("oceanmesh: boundaries")
df_open = df[df["tag"] == "open"]
opengp = gp.GeoDataFrame(df_open, geometry="geometry")
Expand Down Expand Up @@ -549,6 +618,76 @@ def make_oceanmesh(df, **kwargs):
return gr


def om_to_xarray(points, cells, stereo_to_ll=True):
nodes = pd.DataFrame(
data={
"x": points[:, 0],
"y": points[:, 1],
"z": np.zeros(len(points[:, 0])),
}
)
tria = pd.DataFrame(
data={
"a": cells[:, 0],
"b": cells[:, 1],
"c": cells[:, 2],
}
)
nodes = nodes.apply(pd.to_numeric)
tria = tria.apply(pd.to_numeric)

# boundaries (all are islands)
logger.info("oceanmesh: boundaries")

bounds = om.edges.get_boundary_edges(cells)
if len(bounds) > 0:
tbf = pd.DataFrame({"node": np.unique(np.array(bounds).flatten())})
tbf["type"] = "island"
tbf["id"] = -1
else:
tbf = None

tbf = tbf.reset_index(drop=True)
tbf.index.name = "bnodes"

# convert to lat/lon
if stereo_to_ll:
if nodes.z.any() != 0:
xd, yd = to_lat_lon(nodes.x, nodes.y, nodes.z)
nodes["x"] = xd
nodes["y"] = yd
else:
xd, yd = to_lat_lon(nodes.x, nodes.y)
nodes["x"] = xd
nodes["y"] = yd
else:
nodes["x"] = nodes.x
nodes["y"] = nodes.y

els = xr.DataArray(
tria.loc[:, ["a", "b", "c"]].values,
dims=["nSCHISM_hgrid_face", "nMaxSCHISM_hgrid_face_nodes"],
name="SCHISM_hgrid_face_nodes",
)

nod = (
nodes.loc[:, ["x", "y"]]
.to_xarray()
.rename(
{
"index": "nSCHISM_hgrid_node",
"x": "SCHISM_hgrid_node_x",
"y": "SCHISM_hgrid_node_y",
}
)
)
nod = nod.drop_vars("nSCHISM_hgrid_node")

dep = xr.Dataset({"depth": (["nSCHISM_hgrid_node"], np.zeros(nod.nSCHISM_hgrid_node.shape[0]))})

return xr.merge([nod, dep, els, tbf.to_xarray()]) # total


def make_oceanmesh_global(df, **kwargs):
logger.info("Executing oceanmesh")
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -692,69 +831,7 @@ def make_oceanmesh_global(df, **kwargs):
logger.info("oceanmesh: apply Laplacian smoother")
points, cells = om.laplacian2(points, cells, max_iter=iter, pfix=pfix)

nodes = pd.DataFrame(
data={
"x": points[:, 0],
"y": points[:, 1],
"z": np.zeros(len(points[:, 0])),
}
)
tria = pd.DataFrame(
data={
"a": cells[:, 0],
"b": cells[:, 1],
"c": cells[:, 2],
}
)
nodes = nodes.apply(pd.to_numeric)
tria = tria.apply(pd.to_numeric)

# boundaries (all are islands)

logger.info("oceanmesh: boundaries")

bounds = om.edges.get_boundary_edges(cells)
if len(bounds) > 0:
tbf = pd.DataFrame({"node": np.unique(np.array(bounds).flatten())})
tbf["type"] = "island"
tbf["id"] = -1
else:
tbf = None

tbf = tbf.reset_index(drop=True)
tbf.index.name = "bnodes"

# convert to lat/lon
if nodes.z.any() != 0:
xd, yd = to_lat_lon(nodes.x, nodes.y, nodes.z)
nodes["x"] = xd
nodes["y"] = yd
else:
xd, yd = to_lat_lon(nodes.x, nodes.y)
nodes["x"] = xd
nodes["y"] = yd

els = xr.DataArray(
tria.loc[:, ["a", "b", "c"]].values,
dims=["nSCHISM_hgrid_face", "nMaxSCHISM_hgrid_face_nodes"],
name="SCHISM_hgrid_face_nodes",
)

nod = (
nodes.loc[:, ["x", "y"]]
.to_xarray()
.rename(
{
"index": "nSCHISM_hgrid_node",
"x": "SCHISM_hgrid_node_x",
"y": "SCHISM_hgrid_node_y",
}
)
)
nod = nod.drop_vars("nSCHISM_hgrid_node")

dep = xr.Dataset({"depth": (["nSCHISM_hgrid_node"], np.zeros(nod.nSCHISM_hgrid_node.shape[0]))})

gr = xr.merge([nod, dep, els, tbf.to_xarray()]) # total
gr = om_to_xarray(points, cells)
gr = check_mesh(gr)

return gr

0 comments on commit ebb750b

Please sign in to comment.