From 76702750839bca62ece964cbd7a34e267f24a01b Mon Sep 17 00:00:00 2001 From: Brendan Date: Sat, 26 Oct 2024 09:36:14 +0200 Subject: [PATCH 1/4] Perf prog --- src/fiat/gis/overlay.py | 280 +++++++++++++++++++++++---------- src/fiat/models/util.py | 4 +- src/fiat/models/worker_geom.py | 21 ++- test/test_gis.py | 22 +-- 4 files changed, 221 insertions(+), 106 deletions(-) diff --git a/src/fiat/gis/overlay.py b/src/fiat/gis/overlay.py index 3925fb4..afe01ac 100644 --- a/src/fiat/gis/overlay.py +++ b/src/fiat/gis/overlay.py @@ -1,39 +1,141 @@ """Combined vector and raster methods for FIAT.""" -from numpy import array -from osgeo import gdal, ogr, osr +from numpy import any, arange, array, ndarray, ones, stack, tile, zeros_like +from osgeo import ogr from fiat.gis.util import pixel2world, world2pixel from fiat.io import Grid +def intersect_cell( + geom: ogr.Geometry, + x: float | int, + y: float | int, + dx: float | int, + dy: float | int, +): + """_summary_. + + _extended_summary_ + + Parameters + ---------- + geom : ogr.Geometry + _description_ + x : float | int + _description_ + y : float | int + _description_ + dx : float | int + _description_ + dy : float | int + _description_ + """ + x = float(x) + y = float(y) + cell = ogr.Geometry(ogr.wkbPolygon) + ring = ogr.Geometry(ogr.wkbLinearRing) + ring.AddPoint(x, y) + ring.AddPoint(x + dx, y) + ring.AddPoint(x + dx, y + dy) + ring.AddPoint(x, y + dy) + ring.AddPoint(x, y) + cell.AddGeometry(ring) + return geom.Intersects(cell) + + +def rasterize( + x: ndarray, + y: ndarray, + geometry: list | tuple, +) -> ndarray: + """Rasterize a polygon according to the even odd rule. + + Depending on the input, it is either 'center only' or 'all touched'. + + Parameters + ---------- + x : ndarray + A 2D or 3D array of the x coordinates of the raster. + y : ndarray + A 2d or 3D array of the y coordinates of the raster. + geometry : list | tuple, + A list or tuple containing tuples of xy coordinates of the vertices. + + Returns + ------- + ndarray + The resulting binary rasterized polygon. + """ + # Set up array and information + n = len(geometry) + touched = zeros_like(x, dtype=bool) + p1x, p1y = geometry[0] + + for i in range(n + 1): + p2x, p2y = geometry[i % n] + + # Check for a point being inside + mask = (y >= min(p1y, p2y)) & (y <= max(p1y, p2y)) & (x <= max(p1x, p2x)) + if p1y != p2y: + xinters = (y - p1y) * (p2x - p1x) / (p2y - p1y) + p1x + mask &= (p1x == p2x) | (x <= xinters) + touched ^= mask + + # Set for next point + p1x, p1y = p2x, p2y + return touched + + +def vertices( + mask, + xmin, + xmax, + ymin, + ymax, + geometry, +): + """_summary_.""" + px = geometry[:, 0] + py = geometry[:, 1] + mask |= any( + ( + (px[:, None, None] > xmin) + & (px[:, None, None] < xmax) + & (py[:, None, None] > ymin) + & (py[:, None, None] < ymax) + ), + axis=0, + ) + return mask + + def clip( + ft: ogr.Feature, band: Grid, - srs: osr.SpatialReference, gtf: tuple, - ft: ogr.Feature, + all_touched: bool = False, ): """Clip a grid based on a feature (vector). Parameters ---------- + ft : ogr.Feature + A Feature according to the \ +[ogr module](https://gdal.org/api/python/osgeo.ogr.html) of osgeo. + Can be optained by indexing a \ +[GeomSource](/api/GeomSource.qmd). band : Grid An object that contains a connection the band within the dataset. For further information, see [Grid](/api/Grid.qmd)! - srs : osr.SpatialReference - Spatial reference (Projection) of the Grid object (e.g. WGS84). - Can be optained with the \ -[get_srs](/api/GridSource/get_srs.qmd) method. gtf : tuple The geotransform of a grid dataset. Can be optained via the [get_geotransform]\ (/api/GridSource/get_geotransform.qmd) method. Has the following shape: (left, xres, xrot, upper, yrot, yres). - ft : ogr.Feature - A Feature according to the \ -[ogr module](https://gdal.org/api/python/osgeo.ogr.html) of osgeo. - Can be optained by indexing a \ -[GeomSource](/api/GeomSource.qmd). + all_touched : bool, optional + Whether or not to include cells that are 'touched' without covering the + center of the cell. Returns ------- @@ -44,52 +146,54 @@ def clip( -------- - [clip_weighted](/api/overlay/clip_weighted.qmd) """ + # Get the geometry information form the feature geom = ft.GetGeometryRef() - + gtype = geom.GetGeometryType() + gself = geom + if gtype in [3, 6]: + gself = geom.GetGeometryRef(0) + + # Extract information + dx = gtf[1] + dy = gtf[5] minX, maxX, minY, maxY = geom.GetEnvelope() ulX, ulY = world2pixel(gtf, minX, maxY) lrX, lrY = world2pixel(gtf, maxX, minY) - c = pixel2world(gtf, ulX, ulY) - new_gtf = (c[0], gtf[1], 0.0, c[1], 0.0, gtf[-1]) + plX, plY = pixel2world(gtf, ulX, ulY) pxWidth = int(lrX - ulX) + 1 pxHeight = int(lrY - ulY) + 1 + # Create clip and mask clip = band[ulX, ulY, pxWidth, pxHeight] - # m = mask.ReadAsArray(ulX,ulY,pxWidth,pxHeight) - - # pts = geom.GetGeometryRef(0) - # pixels = [None] * pts.GetPointCount() - # for p in range(pts.GetPointCount()): - # pixels[p] = (world2Pixel(gtf, pts.GetX(p), pts.GetY(p))) + mask = ones((pxHeight, pxWidth)) - dr_r = gdal.GetDriverByName("MEM") - b_r = dr_r.Create("memset", pxWidth, pxHeight, 1, gdal.GDT_Int16) - b_r.SetSpatialRef(srs) - b_r.SetGeoTransform(new_gtf) + x = tile(arange(plX + 0.5 * dx, plX + (dx * pxWidth), dx), (pxHeight, 1)) + y = tile(arange(plY + 0.5 * dy, plY + (dy * pxHeight), dy), (pxWidth, 1)).T + # Create 3d arrays when all touched is true, to check for the corners + if all_touched: + x = stack([x - 0.5 * dx, x + 0.5 * dx, x + 0.5 * dx, x - 0.5 * dx]) + y = stack([y - 0.5 * dy, y - 0.5 * dy, y + 0.5 * dy, y + 0.5 * dy]) - dr_g = ogr.GetDriverByName("Memory") - src_g = dr_g.CreateDataSource("memdata") - lay_g = src_g.CreateLayer("mem", srs) - lay_g.CreateFeature(ft) + if gtype > 3: + geometry = gself.GetGeometryRef(0).GetPoints() + else: + geometry = gself.GetPoints() + mask = rasterize(x, y, geometry) - gdal.RasterizeLayer(b_r, [1], lay_g, None, None, [1], ["ALL_TOUCHED=TRUE"]) - clip = clip[b_r.ReadAsArray() == 1] + if all_touched: + mask = any(mask, axis=0) + # Get the vertex touched cells + vertices(mask, x[0], x[1], y[2], y[0], array(geometry)) - b_r = None - dr_r = None - lay_g = None - src_g = None - dr_g = None - - return clip + return clip[mask == 1] def clip_weighted( + ft: ogr.Feature, band: Grid, - srs: osr.SpatialReference, gtf: tuple, - ft: ogr.Feature, - upscale: int = 1, + all_touched: bool = False, + upscale: int = 3, ): """Clip a grid based on a feature (vector), but weighted. @@ -102,24 +206,23 @@ def clip_weighted( Parameters ---------- + ft : ogr.Feature + A Feature according to the \ +[ogr module](https://gdal.org/api/python/osgeo.ogr.html) of osgeo. + Can be optained by indexing a \ +[GeomSource](/api/GeomSource.qmd). band : Grid An object that contains a connection the band within the dataset. For further information, see [Grid](/api/Grid.qmd)! - srs : osr.SpatialReference - Spatial reference (Projection) of the Grid object (e.g. WGS84). - Can be optained with the \ -[get_srs](/api/GridSource/get_srs.qmd) method. gtf : tuple The geotransform of a grid dataset. Can be optained via the [get_geotransform]\ (/api/GridSource/get_geotransform.qmd) method. Has the following shape: (left, xres, xrot, upper, yrot, yres). - ft : ogr.Feature - A Feature according to the \ -[ogr module](https://gdal.org/api/python/osgeo.ogr.html) of osgeo. - Can be optained by indexing a \ -[GeomSource](/api/GeomSource.qmd). - upscale : int + all_touched : bool, optional + Whether or not to include cells that are 'touched' without covering the + center of the cell. + upscale : int, optional How much the underlying grid will be upscaled. The higher the value, the higher the accuracy. @@ -133,46 +236,53 @@ def clip_weighted( - [clip](/api/overlay/clip.qmd) """ geom = ft.GetGeometryRef() - + gtype = geom.GetGeometryType() + gself = geom + if gtype in [3, 6]: + gself = geom.GetGeometryRef(0) + + # Extract information + dx = gtf[1] + dy = gtf[5] minX, maxX, minY, maxY = geom.GetEnvelope() ulX, ulY = world2pixel(gtf, minX, maxY) lrX, lrY = world2pixel(gtf, maxX, minY) - c = pixel2world(gtf, ulX, ulY) - new_gtf = (c[0], gtf[1] / upscale, 0.0, c[1], 0.0, gtf[-1] / upscale) + plX, plY = pixel2world(gtf, ulX, ulY) + dxn = dx / upscale + dyn = dy / upscale pxWidth = int(lrX - ulX) + 1 pxHeight = int(lrY - ulY) + 1 + # Setup clip and mask arrays clip = band[ulX, ulY, pxWidth, pxHeight] - # m = mask.ReadAsArray(ulX,ulY,pxWidth,pxHeight) - - # pts = geom.GetGeometryRef(0) - # pixels = [None] * pts.GetPointCount() - # for p in range(pts.GetPointCount()): - # pixels[p] = (world2Pixel(gtf, pts.GetX(p), pts.GetY(p))) - - dr_r = gdal.GetDriverByName("MEM") - b_r = dr_r.Create( - "memset", pxWidth * upscale, pxHeight * upscale, 1, gdal.GDT_Int16 + x = tile( + arange(plX + 0.5 * dxn, plX + (dxn * (pxWidth * upscale)), dxn), + (pxHeight * upscale, 1), ) - b_r.SetSpatialRef(srs) - b_r.SetGeoTransform(new_gtf) + y = tile( + arange(plY + 0.5 * dyn, plY + (dyn * (pxHeight * upscale)), dyn), + (pxWidth * upscale, 1), + ).T + if all_touched: + x = stack([x - 0.5 * dxn, x + 0.5 * dxn, x + 0.5 * dxn, x - 0.5 * dxn]) + y = stack([y - 0.5 * dyn, y - 0.5 * dyn, y + 0.5 * dyn, y + 0.5 * dyn]) - dr_g = ogr.GetDriverByName("Memory") - src_g = dr_g.CreateDataSource("memdata") - lay_g = src_g.CreateLayer("mem", srs) - lay_g.CreateFeature(ft) + if gtype > 3: + geometry = gself.GetGeometryRef(0).GetPoints() + else: + geometry = gself.GetPoints() + mask = rasterize(x, y, geometry) - gdal.RasterizeLayer(b_r, [1], lay_g, None, None, [1], ["ALL_TOUCHED=TRUE"]) - _w = b_r.ReadAsArray().reshape((pxHeight, upscale, pxWidth, -1)).mean(3).mean(1) - clip = clip[_w != 0] + if all_touched: + mask = any(mask, axis=0) + # Get the vertex touched cells + vertices(mask, x[0], x[1], y[2], y[0], array(geometry)) - b_r = None - dr_r = None - lay_g = None - src_g = None - dr_g = None + # Resample the higher resolution mask + mask = mask.reshape((pxHeight, upscale, pxWidth, -1)).mean(3).mean(1) + clip = clip[mask != 0] - return clip, _w + return clip, mask def mask( @@ -183,14 +293,16 @@ def mask( def pin( + point: tuple, band: Grid, gtf: tuple, - point: tuple, -) -> array: +) -> ndarray: """Pin a the value of a cell based on a coordinate. Parameters ---------- + point : tuple + x and y coordinate. band : Grid Input object. This holds a connection to the specified band. gtf : tuple @@ -198,12 +310,10 @@ def pin( Can be optained via the [get_geotransform]\ (/api/GridSource/get_geotransform.qmd) method. Has the following shape: (left, xres, xrot, upper, yrot, yres). - point : tuple - x and y coordinate. Returns ------- - array + ndarray A NumPy array containing one value. """ x, y = world2pixel(gtf, *point) diff --git a/src/fiat/models/util.py b/src/fiat/models/util.py index 0f86581..0bf1e2a 100644 --- a/src/fiat/models/util.py +++ b/src/fiat/models/util.py @@ -22,11 +22,12 @@ def exposure_from_geom( ft: ogr.Feature, exp: TableLazy, oid: int, + mid: int, idxs_haz: list | tuple, pattern: object, ): """_summary_.""" - method = ft.GetField("extract_method") + method = ft.GetField(mid) haz = [ft.GetField(idx) for idx in idxs_haz] return ft, method, haz @@ -35,6 +36,7 @@ def exposure_from_csv( ft: ogr.Feature, exp: TableLazy, oid: int, + mid: int, idxs_haz: list | tuple, pattern: object, ): diff --git a/src/fiat/models/worker_geom.py b/src/fiat/models/worker_geom.py index a5a2423..f76fcd3 100644 --- a/src/fiat/models/worker_geom.py +++ b/src/fiat/models/worker_geom.py @@ -56,10 +56,12 @@ def worker( rp_coef = risk_density(cfg.get("hazard.return_periods")) rp_coef.reverse() + # Some exposure csv dependent data (or not) + mid = None pattern = None out_text_writer = DummyWriter() if exp_data is not None: - man_columns = [exp_data.columns.index(item) for item in man_columns] + man_columns_idxs = [exp_data.columns.index(item) for item in man_columns] pattern = regex_pattern(exp_data.delimiter) out_text_writer = BufferedTextWriter( Path(cfg.get("output.path"), cfg.get("output.csv.name")), @@ -83,6 +85,9 @@ def worker( total_idx = field_meta["total_idx"] types = field_meta["types"] idxs = field_meta["idxs"] + if exp_data is None: + man_columns_idxs = [gm.fields.index(item) for item in man_columns] + mid = gm.fields.index("extract_method") # Setup the dataset buffer writer out_geom = Path(cfg.get(f"output.geom.name{idx}")) @@ -100,7 +105,8 @@ def worker( ft, exp_data, oid, - man_columns, + mid, + man_columns_idxs, pattern, ) if info is None: @@ -115,10 +121,17 @@ def worker( for band, bn in bands: # How to get the hazard data if method == "area": - res = overlay.clip(band, haz.get_srs(), haz.get_geotransform(), ft) + res = overlay.clip( + ft, + band, + haz.get_geotransform(), + all_touched=True, + ) else: res = overlay.pin( - band, haz.get_geotransform(), geom.point_in_geom(ft) + geom.point_in_geom(ft), + band, + haz.get_geotransform(), ) res[res == band.nodata] = nan diff --git a/test/test_gis.py b/test/test_gis.py index 8a9908b..0ff8db4 100644 --- a/test/test_gis.py +++ b/test/test_gis.py @@ -7,10 +7,9 @@ def test_clip(geom_data, grid_event_data): ft = geom_data[4] hazard = overlay.clip( + ft, grid_event_data[1], - grid_event_data.get_srs(), grid_event_data.get_geotransform(), - ft, ) ft = None @@ -21,31 +20,22 @@ def test_clip(geom_data, grid_event_data): def test_clip_weighted(geom_data, grid_event_data): ft = geom_data[4] _, weights = overlay.clip_weighted( + ft, grid_event_data[1], - grid_event_data.get_srs(), grid_event_data.get_geotransform(), - ft, + all_touched=True, upscale=10, ) assert int(weights[0, 0] * 100) == 90 _, weights = overlay.clip_weighted( - grid_event_data[1], - grid_event_data.get_srs(), - grid_event_data.get_geotransform(), ft, - upscale=100, - ) - assert int(weights[0, 0] * 100) == 80 - - _, weights = overlay.clip_weighted( grid_event_data[1], - grid_event_data.get_srs(), grid_event_data.get_geotransform(), - ft, - upscale=1000, + all_touched=True, + upscale=100, ) - assert int(weights[0, 0] * 100) == 79 + assert int(weights[0, 0] * 100) == 81 def test_pin(geom_data, grid_event_data): From 402ec98398c60e293946d5ad061bf40a7e9dc24f Mon Sep 17 00:00:00 2001 From: Brendan Date: Fri, 1 Nov 2024 14:21:03 +0100 Subject: [PATCH 2/4] Added profiler --- src/fiat/cli/main.py | 14 ++++++++++++-- src/fiat/cli/util.py | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 44 insertions(+), 2 deletions(-) diff --git a/src/fiat/cli/main.py b/src/fiat/cli/main.py index 3c089db..bdd78e6 100644 --- a/src/fiat/cli/main.py +++ b/src/fiat/cli/main.py @@ -7,7 +7,7 @@ from fiat.cfg import ConfigReader from fiat.cli.formatter import MainHelpFormatter -from fiat.cli.util import file_path_check, run_log +from fiat.cli.util import file_path_check, run_log, run_profiler from fiat.log import check_loglevel, setup_default_log from fiat.main import FIAT from fiat.version import __version__ @@ -68,7 +68,10 @@ def run(args): # Kickstart the model obj = FIAT(cfg) - run_log(obj.run, logger=logger) + if args.profile is not None: + run_profiler(obj.run, profile=args.profile, cfg=cfg, logger=logger) + else: + run_log(obj.run, logger=logger) ## Constructing the arguments parser for FIAT. @@ -125,6 +128,13 @@ def args_parser(): "config", help="Path to the settings file", ) + run_parser.add_argument( + "-p", + "--profile", + help=argparse.SUPPRESS, + action="store_const", + const="profile", + ) run_parser.add_argument( "-t", "--threads", diff --git a/src/fiat/cli/util.py b/src/fiat/cli/util.py index 068c467..9f7bdf6 100644 --- a/src/fiat/cli/util.py +++ b/src/fiat/cli/util.py @@ -1,9 +1,12 @@ """Util for cli.""" +import cProfile +import pstats import sys from pathlib import Path from typing import Callable +from fiat.cfg import ConfigReader from fiat.log import Log @@ -33,3 +36,32 @@ def run_log( logger.error(msg) # Exit with code 1 sys.exit(1) + + +def run_profiler( + func: Callable, + profile: str, + cfg: ConfigReader, + logger: Log, +): + """Run the profiler from cli.""" + logger.warning("Running profiler...") + + # Setup the profiler and run the function + profiler = cProfile.Profile() + profiler.enable() + run_log(func, logger=logger) + profiler.disable() + + # Save all the stats + profile_out = cfg.get("output.path") / profile + profiler.dump_stats(profile_out) + logger.info(f"Saved profiling stats to: {profile_out}") + + # Save a human readable portion to a text file + txt_out = cfg.get("output.path") / "profile.txt" + with open(txt_out, "w") as _w: + _w.write(f"Delft-FIAT profile ({cfg.filepath}):\n\n") + stats = pstats.Stats(profiler, stream=_w) + _ = stats.sort_stats("tottime").print_stats() + logger.info(f"Saved profiling stats in human readable format: {txt_out}") From 74278412ca6397ee127c7d4d2161dac0b44ea9c8 Mon Sep 17 00:00:00 2001 From: Brendan Date: Fri, 1 Nov 2024 14:26:18 +0100 Subject: [PATCH 3/4] Argument order --- test/test_gis.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_gis.py b/test/test_gis.py index 0ff8db4..e6846e1 100644 --- a/test/test_gis.py +++ b/test/test_gis.py @@ -43,9 +43,9 @@ def test_pin(geom_data, grid_event_data): XY = geom.point_in_geom(ft) hazard = overlay.pin( + XY, grid_event_data[1], grid_event_data.get_geotransform(), - XY, ) assert int(round(hazard[0] * 100, 0)) == 160 From 5620c5ef7392ad27a46ecb5c307970c5cd54666a Mon Sep 17 00:00:00 2001 From: Brendan Date: Wed, 13 Nov 2024 18:23:00 +0100 Subject: [PATCH 4/4] Update changelog; update cli help --- docs/changelog.qmd | 3 +++ src/fiat/cli/main.py | 14 +++++++------- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/docs/changelog.qmd b/docs/changelog.qmd index 79ac3db..3b9ab33 100644 --- a/docs/changelog.qmd +++ b/docs/changelog.qmd @@ -6,10 +6,13 @@ title: "What's new?" These are the unreleased changes of Delft-FIAT. ### Added +- Custom rasterization functions instead of relying on GDAL (which included dataset creation at runtime) +- Profiler for developers ### Changed - Disabled locks when running 'single threaded' - Fixed logging of errors during settings file checks +- Improved performance when running without csv - Logging class `Log` is now called `Logger` - Specifying destination ('dst') is now optional for `setup_default_log` diff --git a/src/fiat/cli/main.py b/src/fiat/cli/main.py index 91c3a14..49cdfab 100644 --- a/src/fiat/cli/main.py +++ b/src/fiat/cli/main.py @@ -137,13 +137,6 @@ def args_parser(): "config", help="Path to the settings file", ) - run_parser.add_argument( - "-p", - "--profile", - help=argparse.SUPPRESS, - action="store_const", - const="profile", - ) run_parser.add_argument( "-t", "--threads", @@ -167,6 +160,13 @@ def args_parser(): action="count", default=0, ) + run_parser.add_argument( + "-p", + "--profile", + help="Run profiler", + action="store_const", + const="profile", + ) run_parser.set_defaults(func=run) return parser