From 8e85173e7d2789edbd265e23ed0a81d261265d66 Mon Sep 17 00:00:00 2001 From: tomsail Date: Sat, 14 Sep 2024 15:29:33 +0200 Subject: [PATCH] fix: lint --- pyposeidon/telemac.py | 32 ++++++++++++++++---------------- pyposeidon/utils/cfl.py | 4 +--- pyposeidon/utils/cpoint.py | 4 +--- pyposeidon/utils/obs.py | 20 +++++++++++++++++--- pyposeidon/utils/pplot.py | 1 + tests/test_telemac.py | 2 +- tests/utils/test_cpoint.py | 24 ++++++++++++++---------- tests/utils/test_obs.py | 24 +++++++++++++----------- 8 files changed, 64 insertions(+), 47 deletions(-) diff --git a/pyposeidon/telemac.py b/pyposeidon/telemac.py index 64fb0de..4858724 100644 --- a/pyposeidon/telemac.py +++ b/pyposeidon/telemac.py @@ -194,13 +194,19 @@ def write_netcdf(ds, outpath): def extract_t_elev_2D( - ds: xr.Dataset, x: float, y: float, var: str = "elev", xstr: str = "longitude", ystr: str = "latitude", max_dist: float = 1000, + ds: xr.Dataset, + x: float, + y: float, + var: str = "elev", + xstr: str = "longitude", + ystr: str = "latitude", + max_dist: float = 1000, ): lons, lats = ds[xstr].values, ds[ystr].values - mesh = pd.DataFrame(np.vstack([x, y]).T, columns = ["lon", "lat"]) - points = pd.DataFrame(np.vstack([lons, lats]).T, columns = ["lon", "lat"]) + mesh = pd.DataFrame(np.vstack([x, y]).T, columns=["lon", "lat"]) + points = pd.DataFrame(np.vstack([lons, lats]).T, columns=["lon", "lat"]) df = find_nearest_nodes(mesh, points, 1) - df = df[df.distance < max_dist] + df = df[df.distance < max_dist] indx = df["mesh_index"] ds_ = ds.isel(node=indx.values[0]) out_ = ds_[var].values @@ -1102,7 +1108,7 @@ def run(self, api=True, **kwargs): return if self.fortran: - user_fortran = 'user_fortran' + user_fortran = "user_fortran" else: user_fortran = None @@ -1430,16 +1436,10 @@ def set_obs(self, **kwargs): return mesh = pd.DataFrame( - np.array( - [ - self.mesh.Dataset.SCHISM_hgrid_node_x.values, - self.mesh.Dataset.SCHISM_hgrid_node_y.values - ] - ).T, - columns = ["lon", "lat"]) - points = pd.DataFrame( - np.array([tgn.longitude.values, tgn.latitude.values]).T, - columns = ["lon", "lat"]) + np.array([self.mesh.Dataset.SCHISM_hgrid_node_x.values, self.mesh.Dataset.SCHISM_hgrid_node_y.values]).T, + columns=["lon", "lat"], + ) + points = pd.DataFrame(np.array([tgn.longitude.values, tgn.latitude.values]).T, columns=["lon", "lat"]) df = find_nearest_nodes(mesh, points, 1) df = df[df.distance < max_dist] @@ -1453,7 +1453,7 @@ def set_obs(self, **kwargs): # convert to MERCATOR coordinates # dirty fix (this needs to be fixed in TELEMAC directly) - x, y = longlat2spherical(df["lon"], df["lat"],0,0) + x, y = longlat2spherical(df["lon"], df["lat"], 0, 0) df["x"] = x df["y"] = y diff --git a/pyposeidon/utils/cfl.py b/pyposeidon/utils/cfl.py index e769844..7100d2b 100644 --- a/pyposeidon/utils/cfl.py +++ b/pyposeidon/utils/cfl.py @@ -103,9 +103,7 @@ def parse_hgrid(path: os.PathLike[str] | str) -> dict[str, T.Any]: no_closed_boundaries = int(fd.readline().split(b"=")[0].strip()) total_closed_boundary_nodes = int(fd.readline().split(b"=")[0].strip()) for i in range(no_closed_boundaries): - no_nodes_in_boundary, boundary_type = map( - int, (fd.readline().split(b"=")[0].strip().split(b" ")) - ) + no_nodes_in_boundary, boundary_type = map(int, (fd.readline().split(b"=")[0].strip().split(b" "))) boundary_nodes = np.fromiter( fd, count=no_nodes_in_boundary, diff --git a/pyposeidon/utils/cpoint.py b/pyposeidon/utils/cpoint.py index 754fac7..8ef4cf4 100644 --- a/pyposeidon/utils/cpoint.py +++ b/pyposeidon/utils/cpoint.py @@ -101,6 +101,4 @@ def find_nearest_nodes( .assign(distance=(distances.flatten() * earth_radius)) .reset_index(names=["mesh_index"]) ) - return pd.concat( - (points.loc[points.index.repeat(k)].reset_index(drop=True), closest_nodes), axis="columns" - ) + return pd.concat((points.loc[points.index.repeat(k)].reset_index(drop=True), closest_nodes), axis="columns") diff --git a/pyposeidon/utils/obs.py b/pyposeidon/utils/obs.py index b9ef34c..ece45fd 100644 --- a/pyposeidon/utils/obs.py +++ b/pyposeidon/utils/obs.py @@ -1,4 +1,5 @@ """ Observational Data retrieval """ + from __future__ import annotations import itertools @@ -152,14 +153,27 @@ def serialize_stations( msg = f"stations must have these columns too: {mandatory_cols.difference(df_cols)}" raise ValueError(msg) # - basic_cols = ["mesh_lon", "mesh_lat", "z", "separator", "unique_id", "mesh_index", "lon", "lat", "depth", "distance"] + basic_cols = [ + "mesh_lon", + "mesh_lat", + "z", + "separator", + "unique_id", + "mesh_index", + "lon", + "lat", + "depth", + "distance", + ] station_in = stations.assign( z=0, separator="\t!\t", ) - station_in = station_in.set_index(station_in.index +1) + station_in = station_in.set_index(station_in.index + 1) station_in = station_in[basic_cols] with open(f"{path}", "w") as fd: - fd.write(f"{schism_station_flag.strip()}\t ! https://schism-dev.github.io/schism/master/input-output/optional-inputs.html#stationin-bp-format\n") + fd.write( + f"{schism_station_flag.strip()}\t ! https://schism-dev.github.io/schism/master/input-output/optional-inputs.html#stationin-bp-format\n" + ) fd.write(f"{len(station_in)}\t ! number of stations\n") station_in.to_csv(fd, header=None, sep=" ", float_format="%.10f") diff --git a/pyposeidon/utils/pplot.py b/pyposeidon/utils/pplot.py index 8f7af64..de87d4f 100644 --- a/pyposeidon/utils/pplot.py +++ b/pyposeidon/utils/pplot.py @@ -23,6 +23,7 @@ import sys import os + # from pyposeidon.tools import to_geodataframe ffmpeg = sys.exec_prefix + "/bin/ffmpeg" diff --git a/tests/test_telemac.py b/tests/test_telemac.py index 6f90393..63f4505 100644 --- a/tests/test_telemac.py +++ b/tests/test_telemac.py @@ -79,7 +79,7 @@ }, } -case4 = { # test does not work with telemac3d: mesh quality is too bad +case4 = { # test does not work with telemac3d: mesh quality is too bad "solver_name": "telemac", "mesh_file": MESH_FILE, "module": "telemac3d", diff --git a/tests/utils/test_cpoint.py b/tests/utils/test_cpoint.py index e094265..c26b969 100644 --- a/tests/utils/test_cpoint.py +++ b/tests/utils/test_cpoint.py @@ -11,19 +11,23 @@ @pytest.fixture(scope="session") def mesh_nodes(): - return pd.DataFrame({ - "lon": [0, 10, 20], - "lat": [0, 5, 0], - }) + return pd.DataFrame( + { + "lon": [0, 10, 20], + "lat": [0, 5, 0], + } + ) @pytest.fixture(scope="session") def points(): - return pd.DataFrame({ - "lon": [1, 11, 21, 2], - "lat": [1, 4, 1, 2], - "id": ["a", "b", "c", "d"], - }) + return pd.DataFrame( + { + "lon": [1, 11, 21, 2], + "lat": [1, 4, 1, 2], + "id": ["a", "b", "c", "d"], + } + ) @pytest.fixture(scope="session") @@ -41,7 +45,7 @@ def test_find_nearest_nodes(mesh_nodes, points): assert nearest_nodes.distance.max() < 320_000 -@pytest.mark.parametrize("k", [pytest.param(2, id='2 points'), pytest.param(3, id='3 points')]) +@pytest.mark.parametrize("k", [pytest.param(2, id="2 points"), pytest.param(3, id="3 points")]) def test_find_nearest_nodes_multiple_points_and_pass_tree_as_argument(mesh_nodes, points, k, ball_tree): nearest_nodes = find_nearest_nodes(mesh_nodes, points, k=k, tree=ball_tree) assert isinstance(nearest_nodes, pd.DataFrame) diff --git a/tests/utils/test_obs.py b/tests/utils/test_obs.py index 4733953..d1b8718 100644 --- a/tests/utils/test_obs.py +++ b/tests/utils/test_obs.py @@ -17,17 +17,19 @@ def test_serialize_stations(tmp_path): 3 20.0000000000 0.0000000000 0 ! c 2 21.0000000000 1.0000000000 1 157249.3812719441 """ ) - stations = pd.DataFrame({ - 'lon': [1., 11., 21.], - 'lat': [1., 4., 1.], - 'unique_id': ["a", "b", "c"], - 'extra_col': ["AA", "BB", "CC"], - 'mesh_index': [0, 1, 2], - 'mesh_lon': [0., 10., 20.], - 'mesh_lat': [0., 5., 0.], - 'distance': [157249.38127194397, 157010.16264060183, 157249.38127194406], - 'depth': [3, 5, 1], - }) + stations = pd.DataFrame( + { + "lon": [1.0, 11.0, 21.0], + "lat": [1.0, 4.0, 1.0], + "unique_id": ["a", "b", "c"], + "extra_col": ["AA", "BB", "CC"], + "mesh_index": [0, 1, 2], + "mesh_lon": [0.0, 10.0, 20.0], + "mesh_lat": [0.0, 5.0, 0.0], + "distance": [157249.38127194397, 157010.16264060183, 157249.38127194406], + "depth": [3, 5, 1], + } + ) path = tmp_path / "station.in" serialize_stations(stations, path) contents = path.read_text()