diff --git a/dam/processing/warp.py b/dam/processing/warp.py
index 57499ac..df204f8 100644
--- a/dam/processing/warp.py
+++ b/dam/processing/warp.py
@@ -1,9 +1,12 @@
 import xarray as xr
+import rioxarray as rxr
 import rasterio
 import dask
 from typing import Optional
 import numpy as np
+import tempfile
+import os
 
 from ..utils.register_process import as_DAM_process
 
 
@@ -31,8 +34,22 @@ def match_grid(input: xr.DataArray,
                resampling_method: str|int = 'NearestNeighbour',
                nodata_threshold: float = 1,
                nodata_value: Optional[float] = None,
+               engine = 'xarray'
                ) -> xr.DataArray:
+    if engine == 'xarray':
+        return _match_grid_xarray(input, grid, resampling_method, nodata_threshold, nodata_value)
+    elif engine == 'gdal': # this is for compatibility with a previous version. It is not recommended to use gdal
+        return _match_grid_gdal(input, grid, resampling_method, nodata_threshold, nodata_value)
+    else:
+        raise ValueError('engine must be one of [xarray, gdal]')
+
+def _match_grid_xarray(input: xr.DataArray,
+                       grid: xr.DataArray,
+                       resampling_method: str|int,
+                       nodata_threshold: float,
+                       nodata_value: Optional[float]
+                       ) -> xr.DataArray:
 
 
     input_da = input
     mask_da = grid
@@ -73,4 +90,83 @@ def process_chunk(chunk, nodata_value):
 
     nan_mask = nan_mask * 100
     regridded_mask = nan_mask.rio.reproject_match(mask_da, resampling=rasterio.enums.Resampling(5))
-    return input_reprojected.where(regridded_mask < nodata_threshold*100, nodata_value)
\ No newline at end of file
+    return input_reprojected.where(regridded_mask < nodata_threshold*100, nodata_value)
+
+def _match_grid_gdal(input: xr.DataArray,
+                     grid: xr.DataArray,
+                     resampling_method: str|int,
+                     nodata_threshold: float,
+                     nodata_value: Optional[float]) -> xr.DataArray:
+
+    from osgeo import gdal, gdalconst
+
+    _resampling_methods_gdal = ['NearestNeighbour', 'Bilinear',
+                                'Cubic', 'CubicSpline',
+                                'Lanczos',
+                                'Average', 'Mode',
+                                'Max', 'Min',
+                                'Med', 'Q1', 'Q3']
+
+    if isinstance(resampling_method, int):
+        for method in _resampling_methods:  # NOTE(review): presumably the module-level name->code mapping defined above in warp.py -- confirm
+            if _resampling_methods[method] == resampling_method:
+                resampling_method = method
+                break
+
+    for method in _resampling_methods_gdal:
+        if method.lower() == resampling_method.lower():
+            resampling_method = method
+            break
+    else:
+        raise ValueError(f'resampling_method must be one of {_resampling_methods_gdal}')
+
+    # save input and grid to temporary files
+    with tempfile.TemporaryDirectory() as tmpdir:
+        input_path = f'{tmpdir}/input.tif'
+        grid_path = f'{tmpdir}/grid.tif'
+        input.rio.to_raster(input_path)
+        grid.rio.to_raster(grid_path)
+
+        output_path = input_path.replace('.tif', '_regridded.tif')
+
+        # Open the input and reference raster files
+        input_ds = gdal.Open(input_path, gdalconst.GA_ReadOnly)
+        input_transform = input_ds.GetGeoTransform()
+        input_projection = input_ds.GetProjection()
+
+        if nodata_value is not None:
+            input_ds.GetRasterBand(1).SetNoDataValue(nodata_value)
+
+        in_type = input_ds.GetRasterBand(1).DataType
+        input_ds = None
+
+        # Open the reference raster file
+        input_grid = gdal.Open(grid_path, gdalconst.GA_ReadOnly)
+        grid_transform = input_grid.GetGeoTransform()
+        grid_projection = input_grid.GetProjection()
+
+        # Get the resampling method
+        resampling = getattr(gdalconst, f'GRA_{resampling_method}')
+
+        # get the output bounds = the grid bounds
+        # input_bounds = [input_transform[0], input_transform[3], input_transform[0] + input_transform[1] * input_ds.RasterXSize,
+        #                 input_transform[3] + input_transform[5] * input_ds.RasterYSize]
+        output_bounds = [grid_transform[0], grid_transform[3], grid_transform[0] + grid_transform[1] * input_grid.RasterXSize,
+                         grid_transform[3] + grid_transform[5] * input_grid.RasterYSize]
+
+        # set the type of the output to the type of the input if resampling is nearest neighbour, otherwise to float32
+        if resampling == gdalconst.GRA_NearestNeighbour:
+            output_type = in_type
+        else:
+            output_type = gdalconst.GDT_Float32
+
+        os.makedirs(os.path.dirname(output_path), exist_ok=True)
+        gdal.Warp(output_path, input_path, outputBounds=output_bounds, #outputBoundsSRS = input_projection,
+                  srcSRS=input_projection, dstSRS=grid_projection,
+                  xRes=grid_transform[1], yRes=grid_transform[5], resampleAlg=resampling,
+                  outputType=output_type,
+                  format='GTiff', creationOptions=['COMPRESS=LZW'], multithread=True)
+
+        output = rxr.open_rasterio(output_path).load()  # load eagerly: the file is deleted when the tmpdir context exits
+
+    return output
\ No newline at end of file
diff --git a/dam/tools b/dam/tools
index 8fe910b..7682c45 160000
--- a/dam/tools
+++ b/dam/tools
@@ -1 +1 @@
-Subproject commit 8fe910b07513c59c87f03694ce083f33da97009d
+Subproject commit 7682c45a725a86eb9201199a86ac2ed593b2bcec
diff --git a/dam/workflow/workflow.py b/dam/workflow/workflow.py
index c076cb3..fb3377e 100644
--- a/dam/workflow/workflow.py
+++ b/dam/workflow/workflow.py
@@ -44,12 +44,13 @@ def __init__(self,
         self.tmp_dir = tempfile.mkdtemp(dir = tmp_dir)
 
     @classmethod
-    def from_options(cls, options: Options) -> 'DAMWorkflow':
+    def from_options(cls, options: Options|dict) -> 'DAMWorkflow':
+        if isinstance(options, dict): options = Options(options)
         input = options.get('input',ignore_case=True)
-        if isinstance(input, Options):
+        if isinstance(input, dict):
             input = Dataset.from_options(input)
         output = options.get('output', None, ignore_case=True)
-        if isinstance(output, Options):
+        if isinstance(output, dict):
             output = Dataset.from_options(output)
 
         wf_options = options.get('options', None, ignore_case=True)
@@ -70,23 +71,36 @@ def clean_up(self):
         except Exception as e:
             print(f'Error cleaning up temporary directory: {e}')
 
-    def make_output(self, input: Dataset, output: Optional[Dataset|dict] = None) -> Dataset:
+    def make_output(self, input: Dataset, output: Optional[Dataset|dict] = None, name = None) -> Dataset:
         if isinstance(output, Dataset):
             return output
 
+        input_pattern = input.key_pattern
+        input_name = input.name
+        if name is not None:
+            extention = os.path.splitext(input_pattern)[1]
+            output_pattern = input_pattern.replace(f'{extention}', f'_{name}{extention}')
+            output_name = f'{input_name}_{name}'
+        else:
+            output_pattern = input_pattern
+            output_name = input_name
+
         if output is None:
-            key_pattern = input.key_pattern
+            output_pattern = output_pattern
         elif isinstance(output, dict):
-            key_pattern = output.get('key_pattern', input.key_pattern)
+            output_pattern = output.get('key_pattern', output_pattern)
         else:
             raise ValueError('Output must be a Dataset or a dictionary.')
 
         output_type = self.options['intermediate_output']
         if output_type == 'Mem':
-            return MemoryDataset(key_pattern = input.key_pattern)
+            output_ds = MemoryDataset(key_pattern = output_pattern)
         elif output_type == 'Tmp':
-            filename = os.path.basename(key_pattern)
-            return LocalDataset(path = self.tmp_dir, filename = filename)
+            filename = os.path.basename(output_pattern)
+            output_ds = LocalDataset(path = self.tmp_dir, filename = filename)
+
+        output_ds.name = output_name
+        return output_ds
 
     def add_process(self, function, output: Optional[Dataset|dict] = None, **kwargs) -> None:
         if len(self.processes) == 0:
@@ -96,8 +110,7 @@ def add_process(self, function, output: Optional[Dataset|dict] = None, **kwargs)
             previous = self.processes[-1]
             this_input = previous.output
 
-        this_output = self.make_output(this_input, output)
-        this_output.name = f'{this_output.name}_{function.__name__}'
+        this_output = self.make_output(this_input, output, function.__name__)
         this_process = DAMProcessor(function = function,
                                     input = this_input,
                                     args = kwargs,
diff --git a/setup.py b/setup.py
index ccff7da..91227a2 100644
--- a/setup.py
+++ b/setup.py
@@ -6,7 +6,7 @@
 
 setup(
     name='dam',
-    version='1.1.1',
+    version='1.1.0a0',
    packages=find_packages(),
    description='A package for raster data processing developed at the CIMA Research Foundation',
    author='Luca Trotter',